Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed memory leaks from OnnxTransformer #5518

Merged
merged 6 commits into from
Dec 2, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 86 additions & 34 deletions src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -362,8 +362,12 @@ public void Dispose()
_isDisposed = true;
}

private sealed class Mapper : MapperBase
private sealed class Mapper : IRowMapper
{
private readonly IHost _host;
private readonly DataViewSchema _inputSchema;
private readonly Lazy<DataViewSchema.DetachedColumn[]> _outputColumns;

private readonly OnnxTransformer _parent;
/// <summary>
/// <see cref="_inputColIndices"/>'s i-th element value tells the <see cref="IDataView"/> column index to
Expand All @@ -379,9 +383,11 @@ private sealed class Mapper : MapperBase
/// </summary>
private readonly Type[] _inputOnnxTypes;

public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
base(Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper)), inputSchema, parent)
public Mapper(OnnxTransformer parent, DataViewSchema inputSchema)
{
_host = Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper));
_inputSchema = inputSchema;
_outputColumns = new Lazy<DataViewSchema.DetachedColumn[]>(GetOutputColumnsCore);

_parent = parent;
_inputColIndices = new int[_parent.Inputs.Length];
Expand All @@ -401,15 +407,15 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :

var col = inputSchema.GetColumnOrNull(_parent.Inputs[i]);
if (!col.HasValue)
throw Host.ExceptSchemaMismatch(nameof(inputSchema),"input", _parent.Inputs[i]);
throw _host.ExceptSchemaMismatch(nameof(inputSchema),"input", _parent.Inputs[i]);

_inputColIndices[i] = col.Value.Index;

var type = inputSchema[_inputColIndices[i]].Type;
var vectorType = type as VectorDataViewType;

if (vectorType != null && vectorType.Size == 0)
throw Host.Except($"Variable length input columns not supported");
throw _host.Except($"Variable length input columns not supported");

var itemType = type.GetItemType();
var nodeItemType = inputNodeInfo.DataViewType.GetItemType();
Expand All @@ -421,7 +427,7 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
// This is done to support a corner case originated in NimbusML. For more info, see: https://github.com/microsoft/NimbusML/issues/426
var isKeyType = itemType is KeyDataViewType;
if (!isKeyType || itemType.RawType != nodeItemType.RawType)
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Inputs[i], inputNodeInfo.DataViewType.GetItemType().ToString(), type.ToString());
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Inputs[i], inputNodeInfo.DataViewType.GetItemType().ToString(), type.ToString());
}

// If the column is one dimension we make sure that the total size of the Onnx shape matches.
Expand All @@ -433,8 +439,9 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {String.Join(",", inputShape)}, but input data is of length {typeValueCount}.");
}
}
DataViewSchema.DetachedColumn[] IRowMapper.GetOutputColumns() => _outputColumns.Value;

protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
private DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
{
var stdSuffix = ".output";
var info = new DataViewSchema.DetachedColumn[_parent.Outputs.Length];
Expand Down Expand Up @@ -476,17 +483,16 @@ private void AddSlotNames(string columnName, DataViewSchema.Annotations.Builder
builder.AddSlotNames(count, getter);
}

private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput)
private Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput)
{
return col => Enumerable.Range(0, _parent.Outputs.Length).Any(i => activeOutput(i)) && _inputColIndices.Any(i => i == col);
}

private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx);
private void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx);

protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
private Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, OnnxRuntimeOutputCacher outputCacher)
{
disposer = null;
Host.AssertValue(input);
_host.AssertValue(input);

var activeOutputColNames = _parent.Outputs.Where((x, i) => activeOutput(i)).ToArray();

Expand All @@ -495,26 +501,65 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
var elemRawType = vectorType.ItemType.RawType;
var srcNamedValueGetters = GetNamedOnnxValueGetters(input, _inputColIndices, _inputOnnxTypes, _inputTensorShapes);
if (vectorType.ItemType is TextDataViewType)
return MakeStringTensorGetter(input, iinfo, srcNamedValueGetters, activeOutputColNames);
return MakeStringTensorGetter(input, iinfo, srcNamedValueGetters, activeOutputColNames, outputCacher);
else
return Utils.MarshalInvoke(MakeTensorGetter<int>, elemRawType, input, iinfo, srcNamedValueGetters, activeOutputColNames);
return Utils.MarshalInvoke(MakeTensorGetter<int>, elemRawType, input, iinfo, srcNamedValueGetters, activeOutputColNames, outputCacher);
}
else
{
var type = _parent.Model.ModelInfo.OutputsInfo[_parent.MapDataViewColumnToOnnxOutputTensor(iinfo)].DataViewType.RawType;
var srcNamedValueGetters = GetNamedOnnxValueGetters(input, _inputColIndices, _inputOnnxTypes, _inputTensorShapes);
return Utils.MarshalInvoke(MakeObjectGetter<int>, type, input, iinfo, srcNamedValueGetters, activeOutputColNames);
return Utils.MarshalInvoke(MakeObjectGetter<int>, type, input, iinfo, srcNamedValueGetters, activeOutputColNames, outputCacher);
}
}

Delegate[] IRowMapper.CreateGetters(DataViewRow input, Func<int, bool> activeOutput, out Action disposer)
Copy link
Member

@antoniovs1029 antoniovs1029 Dec 1, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious, isn't it possible to mark this method as virtual on the MapperBase class and simply override it here? although I guess implementing IRowMapper directly isn't very bad, it's assymetric with the pattern that all RowToRowTransformerBase children (such as onnxtransformer) use a mapper that derives from MapperBase. #Resolved

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point! I don't know why I didn't think of trying it. Let me try it and see. It will obviously a cleaner solution.


In reply to: 533779304 [](ancestors = 533779304)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like that would make it messier. MakeGetter is also declared in MapperBase as an abstract function and I will have to implement that. And the signature for that is not correct for what OnnxTransformer needs. I can have my version of CreateGetters not call MakeGetter, but then I still need to implement MakeGetter to satisfy the base class. That leaves it a bit inconsistent.


In reply to: 533781345 [](ancestors = 533781345,533779304)

Copy link
Member

@antoniovs1029 antoniovs1029 Dec 1, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense. I've tried to think of an alternative to this, but this does seem to be the best solution to avoid refactoring more things. So I guess it's ok to leave it this way. I guess this is why OnnxTransformer used to have a Cacher for each column, but as you mentioned, this meant no actual caching was done. 😕 #Resolved

Copy link
Member

@eerhardt eerhardt Dec 2, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can have my version of CreateGetters not call MakeGetter, but then I still need to implement MakeGetter to satisfy the base class.

You can just have your override of MakeGetter throw new NotSupportedException("this should never be called."); #Resolved

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea! I have implemented this and reverted back to inheriting from MapperBase


In reply to: 533817299 [](ancestors = 533817299)

{
Contracts.Assert(input.Schema == _inputSchema);

OnnxRuntimeOutputCacher outputCacher = new OnnxRuntimeOutputCacher();

int n = _outputColumns.Value.Length;
var result = new Delegate[n];
for (int i = 0; i < n; i++)
{
if (!activeOutput(i))
continue;
result[i] = MakeGetter(input, i, activeOutput, outputCacher);
}
disposer = () =>
{
outputCacher.Dispose();
};
return result;
}

private class OnnxRuntimeOutputCacher
internal class OnnxRuntimeOutputCacher : IDisposable
Copy link
Member

@antoniovs1029 antoniovs1029 Dec 1, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just wondering, why did this needed to change from private to internal? #Resolved

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. That is a left over from a previous experiment. I will change it back.


In reply to: 533789998 [](ancestors = 533789998)

{
public long Position;
public Dictionary<string, NamedOnnxValue> Outputs;
public Dictionary<string, DisposableNamedOnnxValue> Outputs;
public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> OutputOnnxValues;

public OnnxRuntimeOutputCacher()
{
Position = -1;
Outputs = new Dictionary<string, NamedOnnxValue>();
Outputs = new Dictionary<string, DisposableNamedOnnxValue>();
}

private bool _isDisposed;

protected virtual void Dispose(bool disposing)
{
if (_isDisposed)
return;
OutputOnnxValues?.Dispose();
_isDisposed = true;
}

public void Dispose()
{
Dispose(disposing: true);
GC.SuppressFinalize(this);
}
Copy link
Member

@eerhardt eerhardt Dec 1, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
protected virtual void Dispose(bool disposing)
{
if (_isDisposed)
return;
OutputOnnxValues?.Dispose();
_isDisposed = true;
}
public void Dispose()
{
Dispose(disposing: true);
GC.SuppressFinalize(this);
}
public void Dispose()
{
if (_isDisposed)
return;
OutputOnnxValues?.Dispose();
_isDisposed = true;
}

This is a private class that we never derive from. So you can really simply the pattern here. Along with sealing the class above, you don't need to make a protected virtual void Dispose(bool) method. #Resolved

}

Expand All @@ -529,46 +574,47 @@ private void UpdateCacheIfNeeded(long position, INamedOnnxValueGetter[] srcNamed
inputNameOnnxValues.Add(srcNamedOnnxValueGetters[i].GetNamedOnnxValue());
}

var outputNamedOnnxValues = _parent.Model.Run(inputNameOnnxValues);
Contracts.Assert(outputNamedOnnxValues.Count > 0);
outputCache.OutputOnnxValues?.Dispose();
outputCache.OutputOnnxValues = _parent.Model.Run(inputNameOnnxValues);
Contracts.Assert(outputCache.OutputOnnxValues.Count > 0);

foreach (var outputNameOnnxValue in outputNamedOnnxValues)
foreach (var outputNameOnnxValue in outputCache.OutputOnnxValues)
{
outputCache.Outputs[outputNameOnnxValue.Name] = outputNameOnnxValue;
}
outputCache.Position = position;
}
}

private Delegate MakeTensorGetter<T>(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames)
private Delegate MakeTensorGetter<T>(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters,
string[] activeOutputColNames, OnnxRuntimeOutputCacher outputCacher)
{
Host.AssertValue(input);
var outputCacher = new OnnxRuntimeOutputCacher();
_host.AssertValue(input);
ValueGetter<VBuffer<T>> valueGetter = (ref VBuffer<T> dst) =>
{
UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCacher);
var namedOnnxValue = outputCacher.Outputs[_parent.Outputs[iinfo]];
var tensor = namedOnnxValue.AsTensor<T>() as Microsoft.ML.OnnxRuntime.Tensors.DenseTensor<T>;
if (tensor == null)
throw Host.Except($"Output column {namedOnnxValue.Name} doesn't contain a DenseTensor of expected type {typeof(T)}");
throw _host.Except($"Output column {namedOnnxValue.Name} doesn't contain a DenseTensor of expected type {typeof(T)}");
var editor = VBufferEditor.Create(ref dst, (int)tensor.Length);
tensor.Buffer.Span.CopyTo(editor.Values);
dst = editor.Commit();
};
return valueGetter;
}

private Delegate MakeStringTensorGetter(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames)
private Delegate MakeStringTensorGetter(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters,
string[] activeOutputColNames, OnnxRuntimeOutputCacher outputCacher)
{
Host.AssertValue(input);
var outputCacher = new OnnxRuntimeOutputCacher();
_host.AssertValue(input);
ValueGetter<VBuffer<ReadOnlyMemory<char>>> valueGetter = (ref VBuffer<ReadOnlyMemory<char>> dst) =>
{
UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCacher);
var namedOnnxValue = outputCacher.Outputs[_parent.Outputs[iinfo]];
var tensor = namedOnnxValue.AsTensor<string>() as Microsoft.ML.OnnxRuntime.Tensors.DenseTensor<string>;
if (tensor == null)
throw Host.Except($"Output column {namedOnnxValue.Name} doesn't contain a DenseTensor of expected type {typeof(string)}");
throw _host.Except($"Output column {namedOnnxValue.Name} doesn't contain a DenseTensor of expected type {typeof(string)}");

// Create VBufferEditor to fill "dst" with the values in "denseTensor".
var editor = VBufferEditor.Create(ref dst, (int)tensor.Length);
Expand All @@ -580,14 +626,14 @@ private Delegate MakeStringTensorGetter(DataViewRow input, int iinfo, INamedOnnx
return valueGetter;
}

private Delegate MakeObjectGetter<T>(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames)
private Delegate MakeObjectGetter<T>(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters,
string[] activeOutputColNames, OnnxRuntimeOutputCacher outputCacher)
{
Host.AssertValue(input);
var outputCache = new OnnxRuntimeOutputCacher();
_host.AssertValue(input);
ValueGetter<T> valueGetter = (ref T dst) =>
{
UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCache);
var namedOnnxValue = outputCache.Outputs[_parent.Outputs[iinfo]];
UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCacher);
var namedOnnxValue = outputCacher.Outputs[_parent.Outputs[iinfo]];
var trueValue = namedOnnxValue.AsEnumerable<NamedOnnxValue>().Select(value => value.AsDictionary<string, float>());
var caster = _parent.Model.ModelInfo.OutputsInfo[_parent.MapDataViewColumnToOnnxOutputTensor(iinfo)].Caster;
dst = (T)caster(namedOnnxValue);
Expand Down Expand Up @@ -664,6 +710,12 @@ private static INamedOnnxValueGetter CreateNamedOnnxValueGetterVecCore<T>(DataVi
return new NamedOnnxValueGetterVec<T>(input, colIndex, onnxShape);
}

void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx);

Func<int, bool> IRowMapper.GetDependencies(Func<int, bool> activeOutput) => GetDependenciesCore(activeOutput);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no need for these. Instead you can just rename SaveModel to Save and make it public (the class is private, so it isn't really "public").

The same for GetDependencies.


public ITransformer GetTransformer() => _parent;

/// <summary>
/// Common function for wrapping ML.NET getter as a NamedOnnxValue getter.
/// </summary>
Expand Down
Loading