Skip to content

Commit

Permalink
Fixed memory leak from OnnxTransformer and related x86 build fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
harishsk committed Dec 1, 2020
1 parent a819c70 commit 4b74271
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 68 deletions.
4 changes: 3 additions & 1 deletion Directory.Build.targets
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
<TargetArchitecture Condition="'$(Platform)' == ''">x64</TargetArchitecture>
<NativeTargetArchitecture Condition="'$(NativeTargetArchitecture)' == ''">$(TargetArchitecture)</NativeTargetArchitecture>
<BinDir Condition="'$(BinDir)'==''">$([MSBuild]::NormalizeDirectory('$(RepoRoot)', 'artifacts', 'bin'))</BinDir>
<NativeOutputPath>$(BinDir)Native\$(NativeTargetArchitecture).$(Configuration)\</NativeOutputPath>
<NativeOutputConfig Condition="$(Configuration.Contains('Debug'))">Debug</NativeOutputConfig>
<NativeOutputConfig Condition="$(Configuration.Contains('Release'))">Release</NativeOutputConfig>
<NativeOutputPath>$(BinDir)Native\$(NativeTargetArchitecture).$(NativeOutputConfig)\</NativeOutputPath>

<Platform Condition="'$(Platform)'==''">AnyCPU</Platform>
<PlatformConfig>$(Platform).$(Configuration)</PlatformConfig>
Expand Down
120 changes: 86 additions & 34 deletions src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -362,8 +362,12 @@ public void Dispose()
_isDisposed = true;
}

private sealed class Mapper : MapperBase
private sealed class Mapper : IRowMapper
{
private readonly IHost _host;
private readonly DataViewSchema _inputSchema;
private readonly Lazy<DataViewSchema.DetachedColumn[]> _outputColumns;

private readonly OnnxTransformer _parent;
/// <summary>
/// <see cref="_inputColIndices"/>'s i-th element value tells the <see cref="IDataView"/> column index to
Expand All @@ -379,9 +383,11 @@ private sealed class Mapper : MapperBase
/// </summary>
private readonly Type[] _inputOnnxTypes;

public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
base(Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper)), inputSchema, parent)
public Mapper(OnnxTransformer parent, DataViewSchema inputSchema)
{
_host = Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper));
_inputSchema = inputSchema;
_outputColumns = new Lazy<DataViewSchema.DetachedColumn[]>(GetOutputColumnsCore);

_parent = parent;
_inputColIndices = new int[_parent.Inputs.Length];
Expand All @@ -401,15 +407,15 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :

var col = inputSchema.GetColumnOrNull(_parent.Inputs[i]);
if (!col.HasValue)
throw Host.ExceptSchemaMismatch(nameof(inputSchema),"input", _parent.Inputs[i]);
throw _host.ExceptSchemaMismatch(nameof(inputSchema),"input", _parent.Inputs[i]);

_inputColIndices[i] = col.Value.Index;

var type = inputSchema[_inputColIndices[i]].Type;
var vectorType = type as VectorDataViewType;

if (vectorType != null && vectorType.Size == 0)
throw Host.Except($"Variable length input columns not supported");
throw _host.Except($"Variable length input columns not supported");

var itemType = type.GetItemType();
var nodeItemType = inputNodeInfo.DataViewType.GetItemType();
Expand All @@ -421,7 +427,7 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
// This is done to support a corner case originated in NimbusML. For more info, see: https://github.com/microsoft/NimbusML/issues/426
var isKeyType = itemType is KeyDataViewType;
if (!isKeyType || itemType.RawType != nodeItemType.RawType)
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Inputs[i], inputNodeInfo.DataViewType.GetItemType().ToString(), type.ToString());
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Inputs[i], inputNodeInfo.DataViewType.GetItemType().ToString(), type.ToString());
}

// If the column is one dimension we make sure that the total size of the Onnx shape matches.
Expand All @@ -433,8 +439,9 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {String.Join(",", inputShape)}, but input data is of length {typeValueCount}.");
}
}
DataViewSchema.DetachedColumn[] IRowMapper.GetOutputColumns() => _outputColumns.Value;

protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
private DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
{
var stdSuffix = ".output";
var info = new DataViewSchema.DetachedColumn[_parent.Outputs.Length];
Expand Down Expand Up @@ -476,17 +483,16 @@ private void AddSlotNames(string columnName, DataViewSchema.Annotations.Builder
builder.AddSlotNames(count, getter);
}

private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput)
private Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput)
{
return col => Enumerable.Range(0, _parent.Outputs.Length).Any(i => activeOutput(i)) && _inputColIndices.Any(i => i == col);
}

private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx);
private void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx);

protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
private Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, OnnxRuntimeOutputCacher outputCacher)
{
disposer = null;
Host.AssertValue(input);
_host.AssertValue(input);

var activeOutputColNames = _parent.Outputs.Where((x, i) => activeOutput(i)).ToArray();

Expand All @@ -495,26 +501,65 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
var elemRawType = vectorType.ItemType.RawType;
var srcNamedValueGetters = GetNamedOnnxValueGetters(input, _inputColIndices, _inputOnnxTypes, _inputTensorShapes);
if (vectorType.ItemType is TextDataViewType)
return MakeStringTensorGetter(input, iinfo, srcNamedValueGetters, activeOutputColNames);
return MakeStringTensorGetter(input, iinfo, srcNamedValueGetters, activeOutputColNames, outputCacher);
else
return Utils.MarshalInvoke(MakeTensorGetter<int>, elemRawType, input, iinfo, srcNamedValueGetters, activeOutputColNames);
return Utils.MarshalInvoke(MakeTensorGetter<int>, elemRawType, input, iinfo, srcNamedValueGetters, activeOutputColNames, outputCacher);
}
else
{
var type = _parent.Model.ModelInfo.OutputsInfo[_parent.MapDataViewColumnToOnnxOutputTensor(iinfo)].DataViewType.RawType;
var srcNamedValueGetters = GetNamedOnnxValueGetters(input, _inputColIndices, _inputOnnxTypes, _inputTensorShapes);
return Utils.MarshalInvoke(MakeObjectGetter<int>, type, input, iinfo, srcNamedValueGetters, activeOutputColNames);
return Utils.MarshalInvoke(MakeObjectGetter<int>, type, input, iinfo, srcNamedValueGetters, activeOutputColNames, outputCacher);
}
}

Delegate[] IRowMapper.CreateGetters(DataViewRow input, Func<int, bool> activeOutput, out Action disposer)
{
Contracts.Assert(input.Schema == _inputSchema);

OnnxRuntimeOutputCacher outputCacher = new OnnxRuntimeOutputCacher();

int n = _outputColumns.Value.Length;
var result = new Delegate[n];
for (int i = 0; i < n; i++)
{
if (!activeOutput(i))
continue;
result[i] = MakeGetter(input, i, activeOutput, outputCacher);
}
disposer = () =>
{
outputCacher.Dispose();
};
return result;
}

private class OnnxRuntimeOutputCacher
internal class OnnxRuntimeOutputCacher : IDisposable
{
public long Position;
public Dictionary<string, NamedOnnxValue> Outputs;
public Dictionary<string, DisposableNamedOnnxValue> Outputs;
public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> OutputOnnxValues;

public OnnxRuntimeOutputCacher()
{
Position = -1;
Outputs = new Dictionary<string, NamedOnnxValue>();
Outputs = new Dictionary<string, DisposableNamedOnnxValue>();
}

private bool _isDisposed;

protected virtual void Dispose(bool disposing)
{
if (_isDisposed)
return;
OutputOnnxValues?.Dispose();
_isDisposed = true;
}

public void Dispose()
{
Dispose(disposing: true);
GC.SuppressFinalize(this);
}
}

Expand All @@ -529,46 +574,47 @@ private void UpdateCacheIfNeeded(long position, INamedOnnxValueGetter[] srcNamed
inputNameOnnxValues.Add(srcNamedOnnxValueGetters[i].GetNamedOnnxValue());
}

var outputNamedOnnxValues = _parent.Model.Run(inputNameOnnxValues);
Contracts.Assert(outputNamedOnnxValues.Count > 0);
outputCache.OutputOnnxValues?.Dispose();
outputCache.OutputOnnxValues = _parent.Model.Run(inputNameOnnxValues);
Contracts.Assert(outputCache.OutputOnnxValues.Count > 0);

foreach (var outputNameOnnxValue in outputNamedOnnxValues)
foreach (var outputNameOnnxValue in outputCache.OutputOnnxValues)
{
outputCache.Outputs[outputNameOnnxValue.Name] = outputNameOnnxValue;
}
outputCache.Position = position;
}
}

private Delegate MakeTensorGetter<T>(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames)
private Delegate MakeTensorGetter<T>(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters,
string[] activeOutputColNames, OnnxRuntimeOutputCacher outputCacher)
{
Host.AssertValue(input);
var outputCacher = new OnnxRuntimeOutputCacher();
_host.AssertValue(input);
ValueGetter<VBuffer<T>> valueGetter = (ref VBuffer<T> dst) =>
{
UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCacher);
var namedOnnxValue = outputCacher.Outputs[_parent.Outputs[iinfo]];
var tensor = namedOnnxValue.AsTensor<T>() as Microsoft.ML.OnnxRuntime.Tensors.DenseTensor<T>;
if (tensor == null)
throw Host.Except($"Output column {namedOnnxValue.Name} doesn't contain a DenseTensor of expected type {typeof(T)}");
throw _host.Except($"Output column {namedOnnxValue.Name} doesn't contain a DenseTensor of expected type {typeof(T)}");
var editor = VBufferEditor.Create(ref dst, (int)tensor.Length);
tensor.Buffer.Span.CopyTo(editor.Values);
dst = editor.Commit();
};
return valueGetter;
}

private Delegate MakeStringTensorGetter(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames)
private Delegate MakeStringTensorGetter(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters,
string[] activeOutputColNames, OnnxRuntimeOutputCacher outputCacher)
{
Host.AssertValue(input);
var outputCacher = new OnnxRuntimeOutputCacher();
_host.AssertValue(input);
ValueGetter<VBuffer<ReadOnlyMemory<char>>> valueGetter = (ref VBuffer<ReadOnlyMemory<char>> dst) =>
{
UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCacher);
var namedOnnxValue = outputCacher.Outputs[_parent.Outputs[iinfo]];
var tensor = namedOnnxValue.AsTensor<string>() as Microsoft.ML.OnnxRuntime.Tensors.DenseTensor<string>;
if (tensor == null)
throw Host.Except($"Output column {namedOnnxValue.Name} doesn't contain a DenseTensor of expected type {typeof(string)}");
throw _host.Except($"Output column {namedOnnxValue.Name} doesn't contain a DenseTensor of expected type {typeof(string)}");
// Create VBufferEditor to fill "dst" with the values in "denseTensor".
var editor = VBufferEditor.Create(ref dst, (int)tensor.Length);
Expand All @@ -580,14 +626,14 @@ private Delegate MakeStringTensorGetter(DataViewRow input, int iinfo, INamedOnnx
return valueGetter;
}

private Delegate MakeObjectGetter<T>(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames)
private Delegate MakeObjectGetter<T>(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters,
string[] activeOutputColNames, OnnxRuntimeOutputCacher outputCacher)
{
Host.AssertValue(input);
var outputCache = new OnnxRuntimeOutputCacher();
_host.AssertValue(input);
ValueGetter<T> valueGetter = (ref T dst) =>
{
UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCache);
var namedOnnxValue = outputCache.Outputs[_parent.Outputs[iinfo]];
UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCacher);
var namedOnnxValue = outputCacher.Outputs[_parent.Outputs[iinfo]];
var trueValue = namedOnnxValue.AsEnumerable<NamedOnnxValue>().Select(value => value.AsDictionary<string, float>());
var caster = _parent.Model.ModelInfo.OutputsInfo[_parent.MapDataViewColumnToOnnxOutputTensor(iinfo)].Caster;
dst = (T)caster(namedOnnxValue);
Expand Down Expand Up @@ -664,6 +710,12 @@ private static INamedOnnxValueGetter CreateNamedOnnxValueGetterVecCore<T>(DataVi
return new NamedOnnxValueGetterVec<T>(input, colIndex, onnxShape);
}

void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx);

Func<int, bool> IRowMapper.GetDependencies(Func<int, bool> activeOutput) => GetDependenciesCore(activeOutput);

public ITransformer GetTransformer() => _parent;

/// <summary>
/// Common function for wrapping ML.NET getter as a NamedOnnxValue getter.
/// </summary>
Expand Down
Loading

0 comments on commit 4b74271

Please sign in to comment.