Skip to content

Commit

Permalink
Fixed memory leaks from OnnxTransformer (#5518)
Browse files Browse the repository at this point in the history
* Fixed memory leak from OnnxTransformer and related x86 build fixes

* Reverting x86 build related fixes to focus only on the memory leaks

* Updated docs

* Reverted OnnxRuntimeOutputCacher to private class

* Addressed code review comments

* Refactored OnnxTransform back to using MapperBase based on code review comments
  • Loading branch information
harishsk committed Dec 2, 2020
1 parent f151a4a commit 9a8c46c
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ public static void Example()
//Create the pipeline using onnx file.
var onnxModelPath = "your_path_to_sample_onnx_conversion_1.onnx";
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(onnxModelPath);
var onnxTransformer = onnxEstimator.Fit(trainTestOriginalData.TrainSet);
//Make sure to either use the 'using' clause or explicitly dispose the returned onnxTransformer to prevent memory leaks
using var onnxTransformer = onnxEstimator.Fit(trainTestOriginalData.TrainSet);

//Inference the testset
var output = transformer.Transform(trainTestOriginalData.TestSet);
Expand Down
10 changes: 5 additions & 5 deletions src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ private protected abstract class MapperBase : IRowMapper
{
protected readonly IHost Host;
protected readonly DataViewSchema InputSchema;
private readonly Lazy<DataViewSchema.DetachedColumn[]> _outputColumns;
protected readonly Lazy<DataViewSchema.DetachedColumn[]> OutputColumns;
private readonly RowToRowTransformerBase _parent;

protected MapperBase(IHost host, DataViewSchema inputSchema, RowToRowTransformerBase parent)
Expand All @@ -68,21 +68,21 @@ protected MapperBase(IHost host, DataViewSchema inputSchema, RowToRowTransformer
Host = host;
InputSchema = inputSchema;
_parent = parent;
_outputColumns = new Lazy<DataViewSchema.DetachedColumn[]>(GetOutputColumnsCore);
OutputColumns = new Lazy<DataViewSchema.DetachedColumn[]>(GetOutputColumnsCore);
}

protected abstract DataViewSchema.DetachedColumn[] GetOutputColumnsCore();

DataViewSchema.DetachedColumn[] IRowMapper.GetOutputColumns() => _outputColumns.Value;
DataViewSchema.DetachedColumn[] IRowMapper.GetOutputColumns() => OutputColumns.Value;

Delegate[] IRowMapper.CreateGetters(DataViewRow input, Func<int, bool> activeOutput, out Action disposer)
public virtual Delegate[] CreateGetters(DataViewRow input, Func<int, bool> activeOutput, out Action disposer)
{
// REVIEW: it used to be that the mapper's input schema in the constructor was required to be reference-equal to the schema
// of the input row.
// It still has to be the same schema, but because we may make a transition from lazy to eager schema, the reference-equality
// is no longer always possible. So, we relax the assert as below.
Contracts.Assert(input.Schema == InputSchema);
int n = _outputColumns.Value.Length;
int n = OutputColumns.Value.Length;
var result = new Delegate[n];
var disposers = new Action[n];
for (int i = 0; i < n; i++)
Expand Down
74 changes: 56 additions & 18 deletions src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -484,8 +484,10 @@ private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> a
private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx);

protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
=> throw new NotImplementedException("This should never be called!");

private Delegate CreateGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, OnnxRuntimeOutputCacher outputCacher)
{
disposer = null;
Host.AssertValue(input);

var activeOutputColNames = _parent.Outputs.Where((x, i) => activeOutput(i)).ToArray();
Expand All @@ -495,26 +497,59 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
var elemRawType = vectorType.ItemType.RawType;
var srcNamedValueGetters = GetNamedOnnxValueGetters(input, _inputColIndices, _inputOnnxTypes, _inputTensorShapes);
if (vectorType.ItemType is TextDataViewType)
return MakeStringTensorGetter(input, iinfo, srcNamedValueGetters, activeOutputColNames);
return MakeStringTensorGetter(input, iinfo, srcNamedValueGetters, activeOutputColNames, outputCacher);
else
return Utils.MarshalInvoke(MakeTensorGetter<int>, elemRawType, input, iinfo, srcNamedValueGetters, activeOutputColNames);
return Utils.MarshalInvoke(MakeTensorGetter<int>, elemRawType, input, iinfo, srcNamedValueGetters, activeOutputColNames, outputCacher);
}
else
{
var type = _parent.Model.ModelInfo.OutputsInfo[_parent.MapDataViewColumnToOnnxOutputTensor(iinfo)].DataViewType.RawType;
var srcNamedValueGetters = GetNamedOnnxValueGetters(input, _inputColIndices, _inputOnnxTypes, _inputTensorShapes);
return Utils.MarshalInvoke(MakeObjectGetter<int>, type, input, iinfo, srcNamedValueGetters, activeOutputColNames);
return Utils.MarshalInvoke(MakeObjectGetter<int>, type, input, iinfo, srcNamedValueGetters, activeOutputColNames, outputCacher);
}
}

public override Delegate[] CreateGetters(DataViewRow input, Func<int, bool> activeOutput, out Action disposer)
{
Contracts.Assert(input.Schema == InputSchema);

OnnxRuntimeOutputCacher outputCacher = new OnnxRuntimeOutputCacher();

int n = OutputColumns.Value.Length;
var result = new Delegate[n];
for (int i = 0; i < n; i++)
{
if (!activeOutput(i))
continue;
result[i] = CreateGetter(input, i, activeOutput, outputCacher);
}
disposer = () =>
{
outputCacher.Dispose();
};
return result;
}

private class OnnxRuntimeOutputCacher
private sealed class OnnxRuntimeOutputCacher : IDisposable
{
public long Position;
public Dictionary<string, NamedOnnxValue> Outputs;
public Dictionary<string, DisposableNamedOnnxValue> Outputs;
public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> OutputOnnxValues;

public OnnxRuntimeOutputCacher()
{
Position = -1;
Outputs = new Dictionary<string, NamedOnnxValue>();
Outputs = new Dictionary<string, DisposableNamedOnnxValue>();
}

private bool _isDisposed;

public void Dispose()
{
if (_isDisposed)
return;
OutputOnnxValues?.Dispose();
_isDisposed = true;
}
}

Expand All @@ -529,21 +564,22 @@ private void UpdateCacheIfNeeded(long position, INamedOnnxValueGetter[] srcNamed
inputNameOnnxValues.Add(srcNamedOnnxValueGetters[i].GetNamedOnnxValue());
}

var outputNamedOnnxValues = _parent.Model.Run(inputNameOnnxValues);
Contracts.Assert(outputNamedOnnxValues.Count > 0);
outputCache.OutputOnnxValues?.Dispose();
outputCache.OutputOnnxValues = _parent.Model.Run(inputNameOnnxValues);
Contracts.Assert(outputCache.OutputOnnxValues.Count > 0);

foreach (var outputNameOnnxValue in outputNamedOnnxValues)
foreach (var outputNameOnnxValue in outputCache.OutputOnnxValues)
{
outputCache.Outputs[outputNameOnnxValue.Name] = outputNameOnnxValue;
}
outputCache.Position = position;
}
}

private Delegate MakeTensorGetter<T>(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames)
private Delegate MakeTensorGetter<T>(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters,
string[] activeOutputColNames, OnnxRuntimeOutputCacher outputCacher)
{
Host.AssertValue(input);
var outputCacher = new OnnxRuntimeOutputCacher();
ValueGetter<VBuffer<T>> valueGetter = (ref VBuffer<T> dst) =>
{
UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCacher);
Expand All @@ -558,10 +594,11 @@ private Delegate MakeTensorGetter<T>(DataViewRow input, int iinfo, INamedOnnxVal
return valueGetter;
}

private Delegate MakeStringTensorGetter(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames)
private Delegate MakeStringTensorGetter(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters,
string[] activeOutputColNames, OnnxRuntimeOutputCacher outputCacher)
{
Host.AssertValue(input);
var outputCacher = new OnnxRuntimeOutputCacher();

ValueGetter<VBuffer<ReadOnlyMemory<char>>> valueGetter = (ref VBuffer<ReadOnlyMemory<char>> dst) =>
{
UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCacher);
Expand All @@ -580,14 +617,15 @@ private Delegate MakeStringTensorGetter(DataViewRow input, int iinfo, INamedOnnx
return valueGetter;
}

private Delegate MakeObjectGetter<T>(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames)
private Delegate MakeObjectGetter<T>(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters,
string[] activeOutputColNames, OnnxRuntimeOutputCacher outputCacher)
{
Host.AssertValue(input);
var outputCache = new OnnxRuntimeOutputCacher();

ValueGetter<T> valueGetter = (ref T dst) =>
{
UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCache);
var namedOnnxValue = outputCache.Outputs[_parent.Outputs[iinfo]];
UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCacher);
var namedOnnxValue = outputCacher.Outputs[_parent.Outputs[iinfo]];
var trueValue = namedOnnxValue.AsEnumerable<NamedOnnxValue>().Select(value => value.AsDictionary<string, float>());
var caster = _parent.Model.ModelInfo.OutputsInfo[_parent.MapDataViewColumnToOnnxOutputTensor(iinfo)].Caster;
dst = (T)caster(namedOnnxValue);
Expand Down
71 changes: 40 additions & 31 deletions src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -198,40 +198,49 @@ public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu =
_session = new InferenceSession(modelFile);
}

// Load ONNX model file and parse its input and output schema. The reason of doing so is that ONNXRuntime
// doesn't expose full type information via its C# APIs.
ModelFile = modelFile;
var model = new OnnxCSharpToProtoWrapper.ModelProto();
using (var modelStream = File.OpenRead(modelFile))
using (var codedStream = Google.Protobuf.CodedInputStream.CreateWithLimits(modelStream, Int32.MaxValue, 10))
model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(codedStream);

// Parse actual input and output types stored in the loaded ONNX model to get their DataViewType's.
var inputTypePool = new Dictionary<string, DataViewType>();
foreach (var valueInfo in model.Graph.Input)
inputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);

var initializerTypePool = new Dictionary<string, DataViewType>();
foreach (var valueInfo in model.Graph.Initializer)
initializerTypePool[valueInfo.Name] = OnnxTypeParser.GetScalarDataViewType(valueInfo.DataType);

var outputTypePool = new Dictionary<string, DataViewType>();
// Build casters which maps NamedOnnxValue to .NET objects.
var casterPool = new Dictionary<string, Func<NamedOnnxValue, object>>();
foreach (var valueInfo in model.Graph.Output)
try
{
outputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
casterPool[valueInfo.Name] = OnnxTypeParser.GetDataViewValueCasterAndResultedType(valueInfo.Type, out Type actualType);
}
// Load ONNX model file and parse its input and output schema. The reason of doing so is that ONNXRuntime
// doesn't expose full type information via its C# APIs.
ModelFile = modelFile;
var model = new OnnxCSharpToProtoWrapper.ModelProto();
using (var modelStream = File.OpenRead(modelFile))
using (var codedStream = Google.Protobuf.CodedInputStream.CreateWithLimits(modelStream, Int32.MaxValue, 10))
model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(codedStream);

// Parse actual input and output types stored in the loaded ONNX model to get their DataViewType's.
var inputTypePool = new Dictionary<string, DataViewType>();
foreach (var valueInfo in model.Graph.Input)
inputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);

var initializerTypePool = new Dictionary<string, DataViewType>();
foreach (var valueInfo in model.Graph.Initializer)
initializerTypePool[valueInfo.Name] = OnnxTypeParser.GetScalarDataViewType(valueInfo.DataType);

var outputTypePool = new Dictionary<string, DataViewType>();
// Build casters which maps NamedOnnxValue to .NET objects.
var casterPool = new Dictionary<string, Func<NamedOnnxValue, object>>();
foreach (var valueInfo in model.Graph.Output)
{
outputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
casterPool[valueInfo.Name] = OnnxTypeParser.GetDataViewValueCasterAndResultedType(valueInfo.Type, out Type actualType);
}

var inputInfos = GetOnnxVariablesFromMetadata(_session.InputMetadata, shapeDictionary, inputTypePool, null);
var outputInfos = GetOnnxVariablesFromMetadata(_session.OutputMetadata, shapeDictionary, outputTypePool, casterPool);
var overrideableInitializers = GetOnnxVariablesFromMetadata(_session.OverridableInitializerMetadata, shapeDictionary, inputTypePool, null);
var inputInfos = GetOnnxVariablesFromMetadata(_session.InputMetadata, shapeDictionary, inputTypePool, null);
var outputInfos = GetOnnxVariablesFromMetadata(_session.OutputMetadata, shapeDictionary, outputTypePool, casterPool);
var overrideableInitializers = GetOnnxVariablesFromMetadata(_session.OverridableInitializerMetadata, shapeDictionary, inputTypePool, null);

// Create a view to the used ONNX model from ONNXRuntime's perspective.
ModelInfo = new OnnxModelInfo(inputInfos, outputInfos, overrideableInitializers);
// Create a view to the used ONNX model from ONNXRuntime's perspective.
ModelInfo = new OnnxModelInfo(inputInfos, outputInfos, overrideableInitializers);

Graph = model.Graph;
Graph = model.Graph;
}
catch
{
_session.Dispose();
_session = null;
throw;
}
}

private List<OnnxVariableInfo> GetOnnxVariablesFromMetadata(IReadOnlyDictionary<string, NodeMetadata> nodeMetadata,
Expand Down Expand Up @@ -350,7 +359,7 @@ public static OnnxModel CreateFromBytes(byte[] modelBytes, int? gpuDeviceId = nu
/// </summary>
/// <param name="inputNamedOnnxValues">The NamedOnnxValues to score.</param>
/// <returns>Resulting output NamedOnnxValues list.</returns>
public IReadOnlyCollection<NamedOnnxValue> Run(List<NamedOnnxValue> inputNamedOnnxValues)
public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> Run(List<NamedOnnxValue> inputNamedOnnxValues)
{
return _session.Run(inputNamedOnnxValues);
}
Expand Down
4 changes: 2 additions & 2 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,7 @@ public void RemoveVariablesInPipelineTest()
.Append(mlContext.Transforms.NormalizeMinMax("Features"))
.Append(mlContext.BinaryClassification.Trainers.FastTree(labelColumnName: "Label", featureColumnName: "Features", numberOfLeaves: 2, numberOfTrees: 1, minimumExampleCountPerLeaf: 2));

var model = pipeline.Fit(data);
using var model = pipeline.Fit(data);
var transformedData = model.Transform(data);

var onnxConversionContext = new OnnxContextImpl(mlContext, "A Simple Pipeline", "ML.NET", "0", 0, "machinelearning.dotnet", OnnxVersion.Stable);
Expand Down Expand Up @@ -2029,7 +2029,7 @@ private void TestPipeline<TLastTransformer, TRow>(EstimatorChain<TLastTransforme
private void TestPipeline<TLastTransformer>(EstimatorChain<TLastTransformer> pipeline, IDataView dataView, string onnxFileName, ColumnComparison[] columnsToCompare, string onnxTxtName = null, string onnxTxtSubDir = null)
where TLastTransformer : class, ITransformer
{
var model = pipeline.Fit(dataView);
using var model = pipeline.Fit(dataView);
var transformedData = model.Transform(dataView);
var onnxModel = ML.Model.ConvertToOnnxProtobuf(model, dataView);

Expand Down

0 comments on commit 9a8c46c

Please sign in to comment.