diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/ModelOperations/OnnxConversion.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/ModelOperations/OnnxConversion.cs index 8a55b8fc64..2296e0cd29 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/ModelOperations/OnnxConversion.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/ModelOperations/OnnxConversion.cs @@ -82,7 +82,8 @@ public static void Example() //Create the pipeline using onnx file. var onnxModelPath = "your_path_to_sample_onnx_conversion_1.onnx"; var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(onnxModelPath); - var onnxTransformer = onnxEstimator.Fit(trainTestOriginalData.TrainSet); + //Make sure to either use the 'using' clause or explicitly dispose the returned onnxTransformer to prevent memory leaks + using var onnxTransformer = onnxEstimator.Fit(trainTestOriginalData.TrainSet); //Inference the testset var output = transformer.Transform(trainTestOriginalData.TestSet); diff --git a/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs b/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs index baf3400eaf..5a6b15cf69 100644 --- a/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs +++ b/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs @@ -58,7 +58,7 @@ private protected abstract class MapperBase : IRowMapper { protected readonly IHost Host; protected readonly DataViewSchema InputSchema; - private readonly Lazy _outputColumns; + protected readonly Lazy OutputColumns; private readonly RowToRowTransformerBase _parent; protected MapperBase(IHost host, DataViewSchema inputSchema, RowToRowTransformerBase parent) @@ -68,21 +68,21 @@ protected MapperBase(IHost host, DataViewSchema inputSchema, RowToRowTransformer Host = host; InputSchema = inputSchema; _parent = parent; - _outputColumns = new Lazy(GetOutputColumnsCore); + OutputColumns = new Lazy(GetOutputColumnsCore); } protected abstract DataViewSchema.DetachedColumn[] GetOutputColumnsCore(); - DataViewSchema.DetachedColumn[] IRowMapper.GetOutputColumns() => _outputColumns.Value; + DataViewSchema.DetachedColumn[] IRowMapper.GetOutputColumns() => OutputColumns.Value; - Delegate[] IRowMapper.CreateGetters(DataViewRow input, Func activeOutput, out Action disposer) + public virtual Delegate[] CreateGetters(DataViewRow input, Func activeOutput, out Action disposer) { // REVIEW: it used to be that the mapper's input schema in the constructor was required to be reference-equal to the schema // of the input row. // It still has to be the same schema, but because we may make a transition from lazy to eager schema, the reference-equality // is no longer always possible. So, we relax the assert as below. Contracts.Assert(input.Schema == InputSchema); - int n = _outputColumns.Value.Length; + int n = OutputColumns.Value.Length; var result = new Delegate[n]; var disposers = new Action[n]; for (int i = 0; i < n; i++) diff --git a/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs b/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs index 878fabf6c5..9932019e7a 100644 --- a/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs +++ b/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs @@ -484,8 +484,10 @@ private protected override Func GetDependenciesCore(Func a private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx); protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func activeOutput, out Action disposer) + => throw new NotImplementedException("This should never be called!"); + + private Delegate CreateGetter(DataViewRow input, int iinfo, Func activeOutput, OnnxRuntimeOutputCacher outputCacher) { - disposer = null; Host.AssertValue(input); var activeOutputColNames = _parent.Outputs.Where((x, i) => activeOutput(i)).ToArray(); @@ -495,26 +497,59 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func, elemRawType, input, iinfo, srcNamedValueGetters, activeOutputColNames); + return Utils.MarshalInvoke(MakeTensorGetter, elemRawType, input, iinfo, srcNamedValueGetters, activeOutputColNames, outputCacher); } else { var type = _parent.Model.ModelInfo.OutputsInfo[_parent.MapDataViewColumnToOnnxOutputTensor(iinfo)].DataViewType.RawType; var srcNamedValueGetters = GetNamedOnnxValueGetters(input, _inputColIndices, _inputOnnxTypes, _inputTensorShapes); - return Utils.MarshalInvoke(MakeObjectGetter, type, input, iinfo, srcNamedValueGetters, activeOutputColNames); + return Utils.MarshalInvoke(MakeObjectGetter, type, input, iinfo, srcNamedValueGetters, activeOutputColNames, outputCacher); + } + } + + public override Delegate[] CreateGetters(DataViewRow input, Func activeOutput, out Action disposer) + { + Contracts.Assert(input.Schema == InputSchema); + + OnnxRuntimeOutputCacher outputCacher = new OnnxRuntimeOutputCacher(); + + int n = OutputColumns.Value.Length; + var result = new Delegate[n]; + for (int i = 0; i < n; i++) + { + if (!activeOutput(i)) + continue; + result[i] = CreateGetter(input, i, activeOutput, outputCacher); } + disposer = () => + { + outputCacher.Dispose(); + }; + return result; } - private class OnnxRuntimeOutputCacher + private sealed class OnnxRuntimeOutputCacher : IDisposable { public long Position; - public Dictionary Outputs; + public Dictionary Outputs; + public IDisposableReadOnlyCollection OutputOnnxValues; + public OnnxRuntimeOutputCacher() { Position = -1; - Outputs = new Dictionary(); + Outputs = new Dictionary(); + } + + private bool _isDisposed; + + public void Dispose() + { + if (_isDisposed) + return; + OutputOnnxValues?.Dispose(); + _isDisposed = true; } } @@ -529,10 +564,11 @@ private void UpdateCacheIfNeeded(long position, INamedOnnxValueGetter[] srcNamed inputNameOnnxValues.Add(srcNamedOnnxValueGetters[i].GetNamedOnnxValue()); } - var outputNamedOnnxValues = _parent.Model.Run(inputNameOnnxValues); - Contracts.Assert(outputNamedOnnxValues.Count > 0); + outputCache.OutputOnnxValues?.Dispose(); + outputCache.OutputOnnxValues = _parent.Model.Run(inputNameOnnxValues); + Contracts.Assert(outputCache.OutputOnnxValues.Count > 0); - foreach (var outputNameOnnxValue in outputNamedOnnxValues) + foreach (var outputNameOnnxValue in outputCache.OutputOnnxValues) { outputCache.Outputs[outputNameOnnxValue.Name] = outputNameOnnxValue; } @@ -540,10 +576,10 @@ private void UpdateCacheIfNeeded(long position, INamedOnnxValueGetter[] srcNamed } } - private Delegate MakeTensorGetter(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames) + private Delegate MakeTensorGetter(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, + string[] activeOutputColNames, OnnxRuntimeOutputCacher outputCacher) { Host.AssertValue(input); - var outputCacher = new OnnxRuntimeOutputCacher(); ValueGetter> valueGetter = (ref VBuffer dst) => { UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCacher); @@ -558,10 +594,11 @@ private Delegate MakeTensorGetter(DataViewRow input, int iinfo, INamedOnnxVal return valueGetter; } - private Delegate MakeStringTensorGetter(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames) + private Delegate MakeStringTensorGetter(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, + string[] activeOutputColNames, OnnxRuntimeOutputCacher outputCacher) { Host.AssertValue(input); - var outputCacher = new OnnxRuntimeOutputCacher(); + ValueGetter>> valueGetter = (ref VBuffer> dst) => { UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCacher); @@ -580,14 +617,15 @@ private Delegate MakeStringTensorGetter(DataViewRow input, int iinfo, INamedOnnx return valueGetter; } - private Delegate MakeObjectGetter(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames) + private Delegate MakeObjectGetter(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, + string[] activeOutputColNames, OnnxRuntimeOutputCacher outputCacher) { Host.AssertValue(input); - var outputCache = new OnnxRuntimeOutputCacher(); + ValueGetter valueGetter = (ref T dst) => { - UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCache); - var namedOnnxValue = outputCache.Outputs[_parent.Outputs[iinfo]]; + UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCacher); + var namedOnnxValue = outputCacher.Outputs[_parent.Outputs[iinfo]]; var trueValue = namedOnnxValue.AsEnumerable().Select(value => value.AsDictionary()); var caster = _parent.Model.ModelInfo.OutputsInfo[_parent.MapDataViewColumnToOnnxOutputTensor(iinfo)].Caster; dst = (T)caster(namedOnnxValue); diff --git a/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs b/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs index 02c24b1ad9..4adf18fa40 100644 --- a/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs +++ b/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs @@ -198,40 +198,49 @@ public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = _session = new InferenceSession(modelFile); } - // Load ONNX model file and parse its input and output schema. The reason of doing so is that ONNXRuntime - // doesn't expose full type information via its C# APIs. - ModelFile = modelFile; - var model = new OnnxCSharpToProtoWrapper.ModelProto(); - using (var modelStream = File.OpenRead(modelFile)) - using (var codedStream = Google.Protobuf.CodedInputStream.CreateWithLimits(modelStream, Int32.MaxValue, 10)) - model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(codedStream); - - // Parse actual input and output types stored in the loaded ONNX model to get their DataViewType's. - var inputTypePool = new Dictionary(); - foreach (var valueInfo in model.Graph.Input) - inputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type); - - var initializerTypePool = new Dictionary(); - foreach (var valueInfo in model.Graph.Initializer) - initializerTypePool[valueInfo.Name] = OnnxTypeParser.GetScalarDataViewType(valueInfo.DataType); - - var outputTypePool = new Dictionary(); - // Build casters which maps NamedOnnxValue to .NET objects. - var casterPool = new Dictionary>(); - foreach (var valueInfo in model.Graph.Output) + try { - outputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type); - casterPool[valueInfo.Name] = OnnxTypeParser.GetDataViewValueCasterAndResultedType(valueInfo.Type, out Type actualType); - } + // Load ONNX model file and parse its input and output schema. The reason of doing so is that ONNXRuntime + // doesn't expose full type information via its C# APIs. + ModelFile = modelFile; + var model = new OnnxCSharpToProtoWrapper.ModelProto(); + using (var modelStream = File.OpenRead(modelFile)) + using (var codedStream = Google.Protobuf.CodedInputStream.CreateWithLimits(modelStream, Int32.MaxValue, 10)) + model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(codedStream); + + // Parse actual input and output types stored in the loaded ONNX model to get their DataViewType's. + var inputTypePool = new Dictionary(); + foreach (var valueInfo in model.Graph.Input) + inputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type); + + var initializerTypePool = new Dictionary(); + foreach (var valueInfo in model.Graph.Initializer) + initializerTypePool[valueInfo.Name] = OnnxTypeParser.GetScalarDataViewType(valueInfo.DataType); + + var outputTypePool = new Dictionary(); + // Build casters which maps NamedOnnxValue to .NET objects. + var casterPool = new Dictionary>(); + foreach (var valueInfo in model.Graph.Output) + { + outputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type); + casterPool[valueInfo.Name] = OnnxTypeParser.GetDataViewValueCasterAndResultedType(valueInfo.Type, out Type actualType); + } - var inputInfos = GetOnnxVariablesFromMetadata(_session.InputMetadata, shapeDictionary, inputTypePool, null); - var outputInfos = GetOnnxVariablesFromMetadata(_session.OutputMetadata, shapeDictionary, outputTypePool, casterPool); - var overrideableInitializers = GetOnnxVariablesFromMetadata(_session.OverridableInitializerMetadata, shapeDictionary, inputTypePool, null); + var inputInfos = GetOnnxVariablesFromMetadata(_session.InputMetadata, shapeDictionary, inputTypePool, null); + var outputInfos = GetOnnxVariablesFromMetadata(_session.OutputMetadata, shapeDictionary, outputTypePool, casterPool); + var overrideableInitializers = GetOnnxVariablesFromMetadata(_session.OverridableInitializerMetadata, shapeDictionary, inputTypePool, null); - // Create a view to the used ONNX model from ONNXRuntime's perspective. - ModelInfo = new OnnxModelInfo(inputInfos, outputInfos, overrideableInitializers); + // Create a view to the used ONNX model from ONNXRuntime's perspective. + ModelInfo = new OnnxModelInfo(inputInfos, outputInfos, overrideableInitializers); - Graph = model.Graph; + Graph = model.Graph; + } + catch + { + _session.Dispose(); + _session = null; + throw; + } } private List GetOnnxVariablesFromMetadata(IReadOnlyDictionary nodeMetadata, @@ -350,7 +359,7 @@ public static OnnxModel CreateFromBytes(byte[] modelBytes, int? gpuDeviceId = nu /// /// The NamedOnnxValues to score. /// Resulting output NamedOnnxValues list. - public IReadOnlyCollection Run(List inputNamedOnnxValues) + public IDisposableReadOnlyCollection Run(List inputNamedOnnxValues) { return _session.Run(inputNamedOnnxValues); } diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index 9f97f21ad4..89b838c82e 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -788,7 +788,7 @@ public void RemoveVariablesInPipelineTest() .Append(mlContext.Transforms.NormalizeMinMax("Features")) .Append(mlContext.BinaryClassification.Trainers.FastTree(labelColumnName: "Label", featureColumnName: "Features", numberOfLeaves: 2, numberOfTrees: 1, minimumExampleCountPerLeaf: 2)); - var model = pipeline.Fit(data); + using var model = pipeline.Fit(data); var transformedData = model.Transform(data); var onnxConversionContext = new OnnxContextImpl(mlContext, "A Simple Pipeline", "ML.NET", "0", 0, "machinelearning.dotnet", OnnxVersion.Stable); @@ -2029,7 +2029,7 @@ private void TestPipeline(EstimatorChain(EstimatorChain pipeline, IDataView dataView, string onnxFileName, ColumnComparison[] columnsToCompare, string onnxTxtName = null, string onnxTxtSubDir = null) where TLastTransformer : class, ITransformer { - var model = pipeline.Fit(dataView); + using var model = pipeline.Fit(dataView); var transformedData = model.Transform(dataView); var onnxModel = ML.Model.ConvertToOnnxProtobuf(model, dataView);