-
Notifications
You must be signed in to change notification settings - Fork 1.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fixed memory leaks from OnnxTransformer #5518
Changes from 2 commits
a5875a4
07bc97b
db02163
4a903d2
645a47f
e14eafc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -362,8 +362,12 @@ public void Dispose() | |||||||||||||||||||||||||||||||||||||||||
_isDisposed = true; | ||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
private sealed class Mapper : MapperBase | ||||||||||||||||||||||||||||||||||||||||||
private sealed class Mapper : IRowMapper | ||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||
private readonly IHost _host; | ||||||||||||||||||||||||||||||||||||||||||
private readonly DataViewSchema _inputSchema; | ||||||||||||||||||||||||||||||||||||||||||
private readonly Lazy<DataViewSchema.DetachedColumn[]> _outputColumns; | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
private readonly OnnxTransformer _parent; | ||||||||||||||||||||||||||||||||||||||||||
/// <summary> | ||||||||||||||||||||||||||||||||||||||||||
/// <see cref="_inputColIndices"/>'s i-th element value tells the <see cref="IDataView"/> column index to | ||||||||||||||||||||||||||||||||||||||||||
|
@@ -379,9 +383,11 @@ private sealed class Mapper : MapperBase | |||||||||||||||||||||||||||||||||||||||||
/// </summary> | ||||||||||||||||||||||||||||||||||||||||||
private readonly Type[] _inputOnnxTypes; | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) : | ||||||||||||||||||||||||||||||||||||||||||
base(Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper)), inputSchema, parent) | ||||||||||||||||||||||||||||||||||||||||||
public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) | ||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||
_host = Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper)); | ||||||||||||||||||||||||||||||||||||||||||
_inputSchema = inputSchema; | ||||||||||||||||||||||||||||||||||||||||||
_outputColumns = new Lazy<DataViewSchema.DetachedColumn[]>(GetOutputColumnsCore); | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
_parent = parent; | ||||||||||||||||||||||||||||||||||||||||||
_inputColIndices = new int[_parent.Inputs.Length]; | ||||||||||||||||||||||||||||||||||||||||||
|
@@ -401,15 +407,15 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) : | |||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
var col = inputSchema.GetColumnOrNull(_parent.Inputs[i]); | ||||||||||||||||||||||||||||||||||||||||||
if (!col.HasValue) | ||||||||||||||||||||||||||||||||||||||||||
throw Host.ExceptSchemaMismatch(nameof(inputSchema),"input", _parent.Inputs[i]); | ||||||||||||||||||||||||||||||||||||||||||
throw _host.ExceptSchemaMismatch(nameof(inputSchema),"input", _parent.Inputs[i]); | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
_inputColIndices[i] = col.Value.Index; | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
var type = inputSchema[_inputColIndices[i]].Type; | ||||||||||||||||||||||||||||||||||||||||||
var vectorType = type as VectorDataViewType; | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
if (vectorType != null && vectorType.Size == 0) | ||||||||||||||||||||||||||||||||||||||||||
throw Host.Except($"Variable length input columns not supported"); | ||||||||||||||||||||||||||||||||||||||||||
throw _host.Except($"Variable length input columns not supported"); | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
var itemType = type.GetItemType(); | ||||||||||||||||||||||||||||||||||||||||||
var nodeItemType = inputNodeInfo.DataViewType.GetItemType(); | ||||||||||||||||||||||||||||||||||||||||||
|
@@ -421,7 +427,7 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) : | |||||||||||||||||||||||||||||||||||||||||
// This is done to support a corner case originated in NimbusML. For more info, see: https://github.com/microsoft/NimbusML/issues/426 | ||||||||||||||||||||||||||||||||||||||||||
var isKeyType = itemType is KeyDataViewType; | ||||||||||||||||||||||||||||||||||||||||||
if (!isKeyType || itemType.RawType != nodeItemType.RawType) | ||||||||||||||||||||||||||||||||||||||||||
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Inputs[i], inputNodeInfo.DataViewType.GetItemType().ToString(), type.ToString()); | ||||||||||||||||||||||||||||||||||||||||||
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Inputs[i], inputNodeInfo.DataViewType.GetItemType().ToString(), type.ToString()); | ||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
// If the column is one dimension we make sure that the total size of the Onnx shape matches. | ||||||||||||||||||||||||||||||||||||||||||
|
@@ -433,8 +439,9 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) : | |||||||||||||||||||||||||||||||||||||||||
throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {String.Join(",", inputShape)}, but input data is of length {typeValueCount}."); | ||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
DataViewSchema.DetachedColumn[] IRowMapper.GetOutputColumns() => _outputColumns.Value; | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore() | ||||||||||||||||||||||||||||||||||||||||||
private DataViewSchema.DetachedColumn[] GetOutputColumnsCore() | ||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||
var stdSuffix = ".output"; | ||||||||||||||||||||||||||||||||||||||||||
var info = new DataViewSchema.DetachedColumn[_parent.Outputs.Length]; | ||||||||||||||||||||||||||||||||||||||||||
|
@@ -476,17 +483,16 @@ private void AddSlotNames(string columnName, DataViewSchema.Annotations.Builder | |||||||||||||||||||||||||||||||||||||||||
builder.AddSlotNames(count, getter); | ||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput) | ||||||||||||||||||||||||||||||||||||||||||
private Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput) | ||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||
return col => Enumerable.Range(0, _parent.Outputs.Length).Any(i => activeOutput(i)) && _inputColIndices.Any(i => i == col); | ||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx); | ||||||||||||||||||||||||||||||||||||||||||
private void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx); | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer) | ||||||||||||||||||||||||||||||||||||||||||
private Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, OnnxRuntimeOutputCacher outputCacher) | ||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||
disposer = null; | ||||||||||||||||||||||||||||||||||||||||||
Host.AssertValue(input); | ||||||||||||||||||||||||||||||||||||||||||
_host.AssertValue(input); | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
var activeOutputColNames = _parent.Outputs.Where((x, i) => activeOutput(i)).ToArray(); | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
|
@@ -495,26 +501,65 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b | |||||||||||||||||||||||||||||||||||||||||
var elemRawType = vectorType.ItemType.RawType; | ||||||||||||||||||||||||||||||||||||||||||
var srcNamedValueGetters = GetNamedOnnxValueGetters(input, _inputColIndices, _inputOnnxTypes, _inputTensorShapes); | ||||||||||||||||||||||||||||||||||||||||||
if (vectorType.ItemType is TextDataViewType) | ||||||||||||||||||||||||||||||||||||||||||
return MakeStringTensorGetter(input, iinfo, srcNamedValueGetters, activeOutputColNames); | ||||||||||||||||||||||||||||||||||||||||||
return MakeStringTensorGetter(input, iinfo, srcNamedValueGetters, activeOutputColNames, outputCacher); | ||||||||||||||||||||||||||||||||||||||||||
else | ||||||||||||||||||||||||||||||||||||||||||
return Utils.MarshalInvoke(MakeTensorGetter<int>, elemRawType, input, iinfo, srcNamedValueGetters, activeOutputColNames); | ||||||||||||||||||||||||||||||||||||||||||
return Utils.MarshalInvoke(MakeTensorGetter<int>, elemRawType, input, iinfo, srcNamedValueGetters, activeOutputColNames, outputCacher); | ||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
else | ||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||
var type = _parent.Model.ModelInfo.OutputsInfo[_parent.MapDataViewColumnToOnnxOutputTensor(iinfo)].DataViewType.RawType; | ||||||||||||||||||||||||||||||||||||||||||
var srcNamedValueGetters = GetNamedOnnxValueGetters(input, _inputColIndices, _inputOnnxTypes, _inputTensorShapes); | ||||||||||||||||||||||||||||||||||||||||||
return Utils.MarshalInvoke(MakeObjectGetter<int>, type, input, iinfo, srcNamedValueGetters, activeOutputColNames); | ||||||||||||||||||||||||||||||||||||||||||
return Utils.MarshalInvoke(MakeObjectGetter<int>, type, input, iinfo, srcNamedValueGetters, activeOutputColNames, outputCacher); | ||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
Delegate[] IRowMapper.CreateGetters(DataViewRow input, Func<int, bool> activeOutput, out Action disposer) | ||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||
Contracts.Assert(input.Schema == _inputSchema); | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
OnnxRuntimeOutputCacher outputCacher = new OnnxRuntimeOutputCacher(); | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
int n = _outputColumns.Value.Length; | ||||||||||||||||||||||||||||||||||||||||||
var result = new Delegate[n]; | ||||||||||||||||||||||||||||||||||||||||||
for (int i = 0; i < n; i++) | ||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||
if (!activeOutput(i)) | ||||||||||||||||||||||||||||||||||||||||||
continue; | ||||||||||||||||||||||||||||||||||||||||||
result[i] = MakeGetter(input, i, activeOutput, outputCacher); | ||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
disposer = () => | ||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||
outputCacher.Dispose(); | ||||||||||||||||||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||||||||||||||||||
return result; | ||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
private class OnnxRuntimeOutputCacher | ||||||||||||||||||||||||||||||||||||||||||
internal class OnnxRuntimeOutputCacher : IDisposable | ||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just wondering, why did this needed to change from private to internal? #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch. That is a left over from a previous experiment. I will change it back. In reply to: 533789998 [](ancestors = 533789998) |
||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||
public long Position; | ||||||||||||||||||||||||||||||||||||||||||
public Dictionary<string, NamedOnnxValue> Outputs; | ||||||||||||||||||||||||||||||||||||||||||
public Dictionary<string, DisposableNamedOnnxValue> Outputs; | ||||||||||||||||||||||||||||||||||||||||||
public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> OutputOnnxValues; | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
public OnnxRuntimeOutputCacher() | ||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||
Position = -1; | ||||||||||||||||||||||||||||||||||||||||||
Outputs = new Dictionary<string, NamedOnnxValue>(); | ||||||||||||||||||||||||||||||||||||||||||
Outputs = new Dictionary<string, DisposableNamedOnnxValue>(); | ||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
private bool _isDisposed; | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
protected virtual void Dispose(bool disposing) | ||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||
if (_isDisposed) | ||||||||||||||||||||||||||||||||||||||||||
return; | ||||||||||||||||||||||||||||||||||||||||||
OutputOnnxValues?.Dispose(); | ||||||||||||||||||||||||||||||||||||||||||
_isDisposed = true; | ||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
public void Dispose() | ||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||
Dispose(disposing: true); | ||||||||||||||||||||||||||||||||||||||||||
GC.SuppressFinalize(this); | ||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
This is a |
||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
|
@@ -529,46 +574,47 @@ private void UpdateCacheIfNeeded(long position, INamedOnnxValueGetter[] srcNamed | |||||||||||||||||||||||||||||||||||||||||
inputNameOnnxValues.Add(srcNamedOnnxValueGetters[i].GetNamedOnnxValue()); | ||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
var outputNamedOnnxValues = _parent.Model.Run(inputNameOnnxValues); | ||||||||||||||||||||||||||||||||||||||||||
Contracts.Assert(outputNamedOnnxValues.Count > 0); | ||||||||||||||||||||||||||||||||||||||||||
outputCache.OutputOnnxValues?.Dispose(); | ||||||||||||||||||||||||||||||||||||||||||
outputCache.OutputOnnxValues = _parent.Model.Run(inputNameOnnxValues); | ||||||||||||||||||||||||||||||||||||||||||
Contracts.Assert(outputCache.OutputOnnxValues.Count > 0); | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
foreach (var outputNameOnnxValue in outputNamedOnnxValues) | ||||||||||||||||||||||||||||||||||||||||||
foreach (var outputNameOnnxValue in outputCache.OutputOnnxValues) | ||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||
outputCache.Outputs[outputNameOnnxValue.Name] = outputNameOnnxValue; | ||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
outputCache.Position = position; | ||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
private Delegate MakeTensorGetter<T>(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames) | ||||||||||||||||||||||||||||||||||||||||||
private Delegate MakeTensorGetter<T>(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, | ||||||||||||||||||||||||||||||||||||||||||
string[] activeOutputColNames, OnnxRuntimeOutputCacher outputCacher) | ||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||
Host.AssertValue(input); | ||||||||||||||||||||||||||||||||||||||||||
var outputCacher = new OnnxRuntimeOutputCacher(); | ||||||||||||||||||||||||||||||||||||||||||
_host.AssertValue(input); | ||||||||||||||||||||||||||||||||||||||||||
ValueGetter<VBuffer<T>> valueGetter = (ref VBuffer<T> dst) => | ||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||
UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCacher); | ||||||||||||||||||||||||||||||||||||||||||
var namedOnnxValue = outputCacher.Outputs[_parent.Outputs[iinfo]]; | ||||||||||||||||||||||||||||||||||||||||||
var tensor = namedOnnxValue.AsTensor<T>() as Microsoft.ML.OnnxRuntime.Tensors.DenseTensor<T>; | ||||||||||||||||||||||||||||||||||||||||||
if (tensor == null) | ||||||||||||||||||||||||||||||||||||||||||
throw Host.Except($"Output column {namedOnnxValue.Name} doesn't contain a DenseTensor of expected type {typeof(T)}"); | ||||||||||||||||||||||||||||||||||||||||||
throw _host.Except($"Output column {namedOnnxValue.Name} doesn't contain a DenseTensor of expected type {typeof(T)}"); | ||||||||||||||||||||||||||||||||||||||||||
var editor = VBufferEditor.Create(ref dst, (int)tensor.Length); | ||||||||||||||||||||||||||||||||||||||||||
tensor.Buffer.Span.CopyTo(editor.Values); | ||||||||||||||||||||||||||||||||||||||||||
dst = editor.Commit(); | ||||||||||||||||||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||||||||||||||||||
return valueGetter; | ||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
private Delegate MakeStringTensorGetter(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames) | ||||||||||||||||||||||||||||||||||||||||||
private Delegate MakeStringTensorGetter(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, | ||||||||||||||||||||||||||||||||||||||||||
string[] activeOutputColNames, OnnxRuntimeOutputCacher outputCacher) | ||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||
Host.AssertValue(input); | ||||||||||||||||||||||||||||||||||||||||||
var outputCacher = new OnnxRuntimeOutputCacher(); | ||||||||||||||||||||||||||||||||||||||||||
_host.AssertValue(input); | ||||||||||||||||||||||||||||||||||||||||||
ValueGetter<VBuffer<ReadOnlyMemory<char>>> valueGetter = (ref VBuffer<ReadOnlyMemory<char>> dst) => | ||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||
UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCacher); | ||||||||||||||||||||||||||||||||||||||||||
var namedOnnxValue = outputCacher.Outputs[_parent.Outputs[iinfo]]; | ||||||||||||||||||||||||||||||||||||||||||
var tensor = namedOnnxValue.AsTensor<string>() as Microsoft.ML.OnnxRuntime.Tensors.DenseTensor<string>; | ||||||||||||||||||||||||||||||||||||||||||
if (tensor == null) | ||||||||||||||||||||||||||||||||||||||||||
throw Host.Except($"Output column {namedOnnxValue.Name} doesn't contain a DenseTensor of expected type {typeof(string)}"); | ||||||||||||||||||||||||||||||||||||||||||
throw _host.Except($"Output column {namedOnnxValue.Name} doesn't contain a DenseTensor of expected type {typeof(string)}"); | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
// Create VBufferEditor to fill "dst" with the values in "denseTensor". | ||||||||||||||||||||||||||||||||||||||||||
var editor = VBufferEditor.Create(ref dst, (int)tensor.Length); | ||||||||||||||||||||||||||||||||||||||||||
|
@@ -580,14 +626,14 @@ private Delegate MakeStringTensorGetter(DataViewRow input, int iinfo, INamedOnnx | |||||||||||||||||||||||||||||||||||||||||
return valueGetter; | ||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
private Delegate MakeObjectGetter<T>(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames) | ||||||||||||||||||||||||||||||||||||||||||
private Delegate MakeObjectGetter<T>(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, | ||||||||||||||||||||||||||||||||||||||||||
string[] activeOutputColNames, OnnxRuntimeOutputCacher outputCacher) | ||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||
Host.AssertValue(input); | ||||||||||||||||||||||||||||||||||||||||||
var outputCache = new OnnxRuntimeOutputCacher(); | ||||||||||||||||||||||||||||||||||||||||||
_host.AssertValue(input); | ||||||||||||||||||||||||||||||||||||||||||
ValueGetter<T> valueGetter = (ref T dst) => | ||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||
UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCache); | ||||||||||||||||||||||||||||||||||||||||||
var namedOnnxValue = outputCache.Outputs[_parent.Outputs[iinfo]]; | ||||||||||||||||||||||||||||||||||||||||||
UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCacher); | ||||||||||||||||||||||||||||||||||||||||||
var namedOnnxValue = outputCacher.Outputs[_parent.Outputs[iinfo]]; | ||||||||||||||||||||||||||||||||||||||||||
var trueValue = namedOnnxValue.AsEnumerable<NamedOnnxValue>().Select(value => value.AsDictionary<string, float>()); | ||||||||||||||||||||||||||||||||||||||||||
var caster = _parent.Model.ModelInfo.OutputsInfo[_parent.MapDataViewColumnToOnnxOutputTensor(iinfo)].Caster; | ||||||||||||||||||||||||||||||||||||||||||
dst = (T)caster(namedOnnxValue); | ||||||||||||||||||||||||||||||||||||||||||
|
@@ -664,6 +710,12 @@ private static INamedOnnxValueGetter CreateNamedOnnxValueGetterVecCore<T>(DataVi | |||||||||||||||||||||||||||||||||||||||||
return new NamedOnnxValueGetterVec<T>(input, colIndex, onnxShape); | ||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx); | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
Func<int, bool> IRowMapper.GetDependencies(Func<int, bool> activeOutput) => GetDependenciesCore(activeOutput); | ||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's no need for these. Instead you can just rename The same for |
||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
public ITransformer GetTransformer() => _parent; | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
/// <summary> | ||||||||||||||||||||||||||||||||||||||||||
/// Common function for wrapping ML.NET getter as a NamedOnnxValue getter. | ||||||||||||||||||||||||||||||||||||||||||
/// </summary> | ||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just curious, isn't it possible to mark this method as virtual on the MapperBase class and simply override it here? although I guess implementing IRowMapper directly isn't very bad, it's assymetric with the pattern that all RowToRowTransformerBase children (such as onnxtransformer) use a mapper that derives from MapperBase. #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point! I don't know why I didn't think of trying it. Let me try it and see. It will obviously a cleaner solution.
In reply to: 533779304 [](ancestors = 533779304)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like that would make it messier. MakeGetter is also declared in MapperBase as an abstract function and I will have to implement that. And the signature for that is not correct for what OnnxTransformer needs. I can have my version of CreateGetters not call MakeGetter, but then I still need to implement MakeGetter to satisfy the base class. That leaves it a bit inconsistent.
In reply to: 533781345 [](ancestors = 533781345,533779304)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense. I've tried to think of an alternative to this, but this does seem to be the best solution to avoid refactoring more things. So I guess it's ok to leave it this way. I guess this is why OnnxTransformer used to have a Cacher for each column, but as you mentioned, this meant no actual caching was done. 😕 #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can just have your override of
MakeGetter
throw new NotSupportedException("this should never be called.");
#ResolvedThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good idea! I have implemented this and reverted back to inheriting from MapperBase
In reply to: 533817299 [](ancestors = 533817299)