From a627d5b02d14ff21c1a31a94b1904261211431f6 Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Thu, 20 Sep 2018 10:22:51 -0700 Subject: [PATCH] Create API for extracting information about the nodes in a TensorFlow model (#862) * Add a method that returns TensorFlow model outputs as an ISchema. * Update after merge with master * Address PR comments. * Add metadata with information about the operation type, and the inputs needed for it. * Add method that returns an enumerable of the information about graph nodes, and a console app that displays it * Add the DnnAnalyzer project files. * Address code review comments * Make needed changes after merge with master * Fix bug when there is a node with 1 dimension that is unknown --- Microsoft.ML.sln | 11 ++ src/Microsoft.ML.Data/DataView/SimpleRow.cs | 126 +++++++++----- .../Microsoft.ML.DnnAnalyzer/DnnAnalyzer.cs | 31 ++++ .../Microsoft.ML.DnnAnalyzer.csproj | 19 ++ .../TensorFlow/Tensorflow.cs | 81 +++++++-- .../TensorFlow/TensorflowUtils.cs | 164 +++++++++++++++++- .../TensorflowTransform.cs | 33 +--- .../TensorflowTests.cs | 90 ++++++++++ 8 files changed, 476 insertions(+), 79 deletions(-) create mode 100644 src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/DnnAnalyzer.cs create mode 100644 src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer.csproj diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln index 211a6f7ca3..012d56e705 100644 --- a/Microsoft.ML.sln +++ b/Microsoft.ML.sln @@ -115,6 +115,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Analyzer", "sr EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.StaticPipelineTesting", "test\Microsoft.ML.StaticPipelineTesting\Microsoft.ML.StaticPipelineTesting.csproj", "{8B38BF24-35F4-4787-A9C5-22D35987106E}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.ML.DnnAnalyzer", "src\Microsoft.ML.DnnAnalyzer\Microsoft.ML.DnnAnalyzer\Microsoft.ML.DnnAnalyzer.csproj", "{73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -419,6 +421,14 @@ Global {8B38BF24-35F4-4787-A9C5-22D35987106E}.Release|Any CPU.Build.0 = Release|Any CPU {8B38BF24-35F4-4787-A9C5-22D35987106E}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU {8B38BF24-35F4-4787-A9C5-22D35987106E}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU + {73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Debug|Any CPU.Build.0 = Debug|Any CPU + {73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU + {73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU + {73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Release|Any CPU.ActiveCfg = Release|Any CPU + {73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Release|Any CPU.Build.0 = Release|Any CPU + {73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU + {73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -466,6 +476,7 @@ Global {570A0B8A-5463-44D2-8521-54C0CA4CACA9} = {09EADF06-BE25-4228-AB53-95AE3E15B530} {6DEF0F40-3853-47B3-8165-5F24BA5E14DF} = {09EADF06-BE25-4228-AB53-95AE3E15B530} {8B38BF24-35F4-4787-A9C5-22D35987106E} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} + {73DAAC82-D308-48CC-8FFE-3B037F8BBCCA} = {09EADF06-BE25-4228-AB53-95AE3E15B530} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D} diff --git a/src/Microsoft.ML.Data/DataView/SimpleRow.cs b/src/Microsoft.ML.Data/DataView/SimpleRow.cs index b94f8f90ec..b0ba12b5ab 100644 --- a/src/Microsoft.ML.Data/DataView/SimpleRow.cs +++ b/src/Microsoft.ML.Data/DataView/SimpleRow.cs @@ -64,97 +64,135 @@ public bool IsColumnActive(int col) /// An that takes all column names and types as constructor parameters. /// The columns do not have metadata. /// - public sealed class SimpleSchema : ISchema + public abstract class SimpleSchemaBase : ISchema { - private readonly IExceptionContext _ectx; + protected readonly IExceptionContext Ectx; private readonly string[] _names; - private readonly ColumnType[] _types; - private readonly Dictionary _columnNameMap; - private readonly MetadataUtils.MetadataGetter>>[] _keyValueGetters; + protected readonly ColumnType[] Types; + protected readonly Dictionary ColumnNameMap; - public int ColumnCount => _types.Length; + public int ColumnCount => Types.Length; - public SimpleSchema(IExceptionContext ectx, params KeyValuePair[] columns) + protected SimpleSchemaBase(IExceptionContext ectx, params KeyValuePair[] columns) { Contracts.CheckValueOrNull(ectx); - _ectx = ectx; - _ectx.CheckValue(columns, nameof(columns)); + Ectx = ectx; + Ectx.CheckValue(columns, nameof(columns)); _names = new string[columns.Length]; - _types = new ColumnType[columns.Length]; - _columnNameMap = new Dictionary(); + Types = new ColumnType[columns.Length]; + ColumnNameMap = new Dictionary(); for (int i = 0; i < columns.Length; i++) { _names[i] = columns[i].Key; - _types[i] = columns[i].Value; - if (_columnNameMap.ContainsKey(columns[i].Key)) + Types[i] = columns[i].Value; + if (ColumnNameMap.ContainsKey(columns[i].Key)) throw ectx.ExceptParam(nameof(columns), $"Duplicate column name: '{columns[i].Key}'"); - _columnNameMap[columns[i].Key] = i; - } - _keyValueGetters = new MetadataUtils.MetadataGetter>>[ColumnCount]; - } - - public SimpleSchema(IExceptionContext ectx, KeyValuePair[] columns, Dictionary>>> keyValues) - : this(ectx, columns) - { - foreach (var kvp in keyValues) - { - var name = kvp.Key; - var getter = kvp.Value; - if (!_columnNameMap.TryGetValue(name, out int col)) - throw _ectx.ExceptParam(nameof(keyValues), $"Output schema does not contain column '{name}'"); - if (!_types[col].ItemType.IsKey) - throw _ectx.ExceptParam(nameof(keyValues), $"Column '{name}' is not a key column, so it cannot have key value metadata"); - _keyValueGetters[col] = getter; + ColumnNameMap[columns[i].Key] = i; } } public bool TryGetColumnIndex(string name, out int col) { - return _columnNameMap.TryGetValue(name, out col); + return ColumnNameMap.TryGetValue(name, out col); } public string GetColumnName(int col) { - _ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col)); + Ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col)); return _names[col]; } public ColumnType GetColumnType(int col) { - _ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - return _types[col]; + Ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col)); + return Types[col]; } public IEnumerable> GetMetadataTypes(int col) { - _ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col)); + Ectx.Assert(0 <= col && col < ColumnCount); + return GetMetadataTypesCore(col); + } + + protected abstract IEnumerable> GetMetadataTypesCore(int col); + + public ColumnType GetMetadataTypeOrNull(string kind, int col) + { + Ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col)); + return GetMetadataTypeOrNullCore(kind, col); + } + + protected abstract ColumnType GetMetadataTypeOrNullCore(string kind, int col); + + public void GetMetadata(string kind, int col, ref TValue value) + { + Ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col)); + GetMetadataCore(kind, col, ref value); + } + + protected abstract void GetMetadataCore(string kind, int col, ref TValue value); + } + + /// + /// An that takes all column names and types as constructor parameters. + /// The columns can optionally have text metadata. + /// + public sealed class SimpleSchema : SimpleSchemaBase + { + private readonly MetadataUtils.MetadataGetter>>[] _keyValueGetters; + + public SimpleSchema(IExceptionContext ectx, params KeyValuePair[] columns) + : base(ectx, columns) + { + _keyValueGetters = new MetadataUtils.MetadataGetter>>[ColumnCount]; + } + + public SimpleSchema(IExceptionContext ectx, KeyValuePair[] columns, + Dictionary>>> keyValues) + : this(ectx, columns) + { + foreach (var kvp in keyValues) + { + var name = kvp.Key; + var getter = kvp.Value; + if (!ColumnNameMap.TryGetValue(name, out int col)) + throw Ectx.ExceptParam(nameof(keyValues), $"Output schema does not contain column '{name}'"); + if (!Types[col].ItemType.IsKey) + throw Ectx.ExceptParam(nameof(keyValues), $"Column '{name}' is not a key column, so it cannot have key value metadata"); + _keyValueGetters[col] = getter; + } + } + + protected override IEnumerable> GetMetadataTypesCore(int col) + { + Ectx.Assert(0 <= col && col < ColumnCount); if (_keyValueGetters[col] != null) { - _ectx.Assert(_types[col].ItemType.IsKey); + Ectx.Assert(Types[col].ItemType.IsKey); yield return new KeyValuePair(MetadataUtils.Kinds.KeyValues, - new VectorType(TextType.Instance, _types[col].ItemType.KeyCount)); + new VectorType(TextType.Instance, Types[col].ItemType.KeyCount)); } } - public ColumnType GetMetadataTypeOrNull(string kind, int col) + protected override ColumnType GetMetadataTypeOrNullCore(string kind, int col) { - _ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col)); + Ectx.Assert(0 <= col && col < ColumnCount); if (kind == MetadataUtils.Kinds.KeyValues && _keyValueGetters[col] != null) { - _ectx.Assert(_types[col].ItemType.IsKey); - return new VectorType(TextType.Instance, _types[col].ItemType.KeyCount); + Ectx.Assert(Types[col].ItemType.IsKey); + return new VectorType(TextType.Instance, Types[col].ItemType.KeyCount); } return null; } - public void GetMetadata(string kind, int col, ref TValue value) + protected override void GetMetadataCore(string kind, int col, ref TValue value) { - _ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col)); + Ectx.Assert(0 <= col && col < ColumnCount); if (kind == MetadataUtils.Kinds.KeyValues && _keyValueGetters[col] != null) _keyValueGetters[col].Marshal(col, ref value); else - throw _ectx.ExceptGetMetadata(); + throw Ectx.ExceptGetMetadata(); } } } \ No newline at end of file diff --git a/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/DnnAnalyzer.cs b/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/DnnAnalyzer.cs new file mode 100644 index 0000000000..48fd32fc31 --- /dev/null +++ b/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/DnnAnalyzer.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Transforms.TensorFlow; +using System; +using System.Linq; + +namespace Microsoft.ML.DnnAnalyzer +{ + public static class DnnAnalyzer + { + public static void Main(string[] args) + { + if (Utils.Size(args) != 1) + { + Console.Error.WriteLine("Usage: dotnet DnnAnalyzer.dll "); + return; + } + + foreach (var (name, opType, type, inputs) in TensorFlowUtils.GetModelNodes(args[0])) + { + var inputsString = inputs.Length == 0 ? "" : $", input nodes: {string.Join(", ", inputs)}"; + Console.WriteLine($"Graph node: '{name}', operation type: '{opType}', output type: '{type}'{inputsString}"); + } + } + } +} diff --git a/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer.csproj b/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer.csproj new file mode 100644 index 0000000000..7c77ff2ffa --- /dev/null +++ b/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer.csproj @@ -0,0 +1,19 @@ + + + + Exe + netcoreapp2.1 + DnnAnalyzer + Microsoft.ML.TensorFlow + + + + + + + + + + + + diff --git a/src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs b/src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs index 4fd4258794..e63e4f56c2 100644 --- a/src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs +++ b/src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs @@ -23,6 +23,7 @@ using size_t = System.UIntPtr; using System.Collections.Generic; +using System.Collections; #pragma warning disable MSML_GeneralName #pragma warning disable MSML_PrivateFieldName @@ -492,7 +493,7 @@ public void SetConfig(IntPtr protoData, int length, TFStatus status = null) /// "hot", and add a "sub" operation there the result will be "demo/hot/sub". /// /// - internal partial class TFGraph : TFDisposableThreadSafe + internal partial class TFGraph : TFDisposableThreadSafe, IEnumerable { // extern TF_Graph * TF_NewGraph (); [DllImport(NativeBinding.TensorFlowLibrary)] @@ -696,6 +697,33 @@ public override string ToString() IntPtr len; return TF_GraphDebugString(Handle, out len); } + + [DllImport(NativeBinding.TensorFlowLibrary)] + private static unsafe extern TF_Operation TF_GraphNextOperation(TF_Graph graph, ref IntPtr pos); + + /// + /// Returns the enumerator that returns all the TFOperations in a graph. + /// + /// The enumerator. + private IEnumerable GetEnumerable() + { + if (handle == IntPtr.Zero) + ObjectDisposedException(); + IntPtr token = IntPtr.Zero; + IntPtr operll; + while ((operll = TF_GraphNextOperation(handle, ref token)) != IntPtr.Zero) + yield return new TFOperation(this, operll); + } + + public IEnumerator GetEnumerator() + { + return GetEnumerable().GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } } /// @@ -736,6 +764,48 @@ public TFOutput this[int idx] return new TFOutput(this, idx); } } + + // extern TF_Output TF_OperationInput (TF_Input oper_in); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern TFOutput TF_OperationInput(TFInput oper_in); + + public TFOutput GetInput(int idx) + { + return TF_OperationInput(new TFInput() { Operation = handle, Index = idx }); + } + + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern IntPtr TF_OperationName(TF_Operation oper); + + /// + /// The name for this operation/ + /// + /// The name. + public string Name => handle == IntPtr.Zero ? "" : TF_OperationName(handle).GetStr(); + + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern IntPtr TF_OperationOpType(TF_Operation oper); + + public string OpType => handle == IntPtr.Zero ? "" : TF_OperationOpType(handle).GetStr(); + + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern int TF_OperationNumOutputs(TF_Operation oper); + + /// + /// Gets the number of outputs on this operation. + /// + /// The number outputs. + public int NumOutputs => handle == IntPtr.Zero ? -1 : TF_OperationNumOutputs(handle); + + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern int TF_OperationNumInputs(TF_Operation oper); + + /// + /// Gets the number of inputs for this operation. + /// Import a serialized graph into this graph, using the specified importing options. + /// + /// The number inputs. + public int NumInputs => TF_OperationNumInputs(handle); } /// @@ -1768,15 +1838,6 @@ internal struct TFInput /// public int Index; - // extern TF_Output TF_OperationInput (TF_Input oper_in); - [DllImport(NativeBinding.TensorFlowLibrary)] - private static extern TFOutput TF_OperationInput(TFInput oper_in); - - public TFOutput GetOutput(TFInput operIn) - { - return TF_OperationInput(operIn); - } - // extern TF_DataType TF_OperationInputType (TF_Input oper_in); [DllImport(NativeBinding.TensorFlowLibrary)] private static extern TFDataType TF_OperationInputType(TFInput oper_in); diff --git a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs index e77309c8b0..54030aec91 100644 --- a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs +++ b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs @@ -4,16 +4,21 @@ using System; using System.Collections.Generic; +using System.IO; using System.Linq; using System.Runtime.InteropServices; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.ImageAnalytics.EntryPoints; +using Microsoft.ML.Runtime.Internal.Utilities; namespace Microsoft.ML.Transforms.TensorFlow { public static class TensorFlowUtils { + public const string OpType = "OpType"; + public const string InputOps = "InputOps"; + // This method is needed for the Pipeline API, since ModuleCatalog does not load entry points that are located // in assemblies that aren't directly used in the code. Users who want to use TensorFlow components will have to call // TensorFlowUtils.Initialize() before creating the pipeline. @@ -25,7 +30,95 @@ public static void Initialize() ImageAnalytics.Initialize(); } + private static ISchema GetModelSchema(IExceptionContext ectx, TFGraph graph) + { + var res = new List>(); + var opTypeGetters = new List>>(); + var inputOpsGetters = new List>>>(); + var inputOpsLengths = new List(); + foreach (var op in graph) + { + var tfType = op[0].OutputType; + var mlType = Tf2MlNetTypeOrNull(tfType); + + // If the type is not supported in ML.NET then we cannot represent it as a column in an ISchema. + // We also cannot output it with a TensorFlowTransform, so we skip it. + if (mlType == null) + continue; + + var shape = graph.GetTensorShape(op[0]); + var shapeArray = shape.ToIntArray(); + + inputOpsLengths.Add(op.NumInputs); + MetadataUtils.MetadataGetter>> inputOpsGetter = null; + if (op.NumInputs > 0) + { + var inputOps = new ReadOnlyMemory[op.NumInputs]; + for (int i = 0; i < op.NumInputs; i++) + { + var input = op.GetInput(i); + inputOps[i] = new ReadOnlyMemory(input.Operation.Name.ToArray()); + } + inputOpsGetter = (int col, ref VBuffer> dst) => + dst = new VBuffer>(op.NumInputs, inputOps); + } + inputOpsGetters.Add(inputOpsGetter); + + var opType = op.OpType; + MetadataUtils.MetadataGetter> opTypeGetter = + (int col, ref ReadOnlyMemory dst) => dst = new ReadOnlyMemory(opType.ToArray()); + opTypeGetters.Add(opTypeGetter); + + var columnType = Utils.Size(shapeArray) == 1 && shapeArray[0] == -1 ? new VectorType(mlType) : + Utils.Size(shapeArray) > 0 && shapeArray.Skip(1).All(x => x > 0) ? + new VectorType(mlType, shapeArray[0] > 0 ? shapeArray : shapeArray.Skip(1).ToArray()) + : new VectorType(mlType); + res.Add(new KeyValuePair(op.Name, columnType)); + } + return new TensorFlowSchema(ectx, res.ToArray(), opTypeGetters.ToArray(), inputOpsGetters.ToArray(), inputOpsLengths.ToArray()); + } + + public static ISchema GetModelSchema(IExceptionContext ectx, string modelFile) + { + var bytes = File.ReadAllBytes(modelFile); + var session = LoadTFSession(ectx, bytes, modelFile); + return GetModelSchema(ectx, session.Graph); + } + + public static IEnumerable<(string, string, ColumnType, string[])> GetModelNodes(string modelFile) + { + var schema = GetModelSchema(null, modelFile); + + for (int i = 0; i < schema.ColumnCount; i++) + { + var name = schema.GetColumnName(i); + var type = schema.GetColumnType(i); + + var metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.OpType, i); + Contracts.Assert(metadataType != null && metadataType.IsText); + ReadOnlyMemory opType = default; + schema.GetMetadata(TensorFlowUtils.OpType, i, ref opType); + metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.InputOps, i); + VBuffer> inputOps = default; + if (metadataType != null) + { + Contracts.Assert(metadataType.IsKnownSizeVector && metadataType.ItemType.IsText); + schema.GetMetadata(TensorFlowUtils.InputOps, i, ref inputOps); + } + yield return (name, opType.ToString(), type, + Utils.Size(inputOps.Values) > 0 ? inputOps.Values.Select(input => input.ToString()).ToArray() : new string[0]); + } + } + internal static PrimitiveType Tf2MlNetType(TFDataType type) + { + var mlNetType = Tf2MlNetTypeOrNull(type); + if (mlNetType == null) + throw new NotSupportedException("TensorFlow type not supported."); + return mlNetType; + } + + private static PrimitiveType Tf2MlNetTypeOrNull(TFDataType type) { switch (type) { @@ -42,10 +135,29 @@ internal static PrimitiveType Tf2MlNetType(TFDataType type) case TFDataType.UInt64: return NumberType.U8; default: - throw new NotSupportedException("TensorFlow type not supported."); + return null; } } + internal static TFSession LoadTFSession(IExceptionContext ectx, byte[] modelBytes, string modelFile = null) + { + var graph = new TFGraph(); + try + { + graph.Import(modelBytes, ""); + } + catch (Exception ex) + { + if (!string.IsNullOrEmpty(modelFile)) + throw ectx.Except($"TensorFlow exception triggered while loading model from '{modelFile}'"); +#pragma warning disable MSML_NoMessagesForLoadContext + throw ectx.ExceptDecode(ex, "Tensorflow exception triggered while loading model."); +#pragma warning restore MSML_NoMessagesForLoadContext + + } + return new TFSession(graph); + } + internal static unsafe void FetchData(IntPtr data, T[] result) { var size = result.Length; @@ -73,5 +185,55 @@ internal static bool IsTypeSupported(TFDataType tfoutput) return false; } } + + private sealed class TensorFlowSchema : SimpleSchemaBase + { + private readonly MetadataUtils.MetadataGetter>[] _opTypeGetters; + private readonly MetadataUtils.MetadataGetter>>[] _inputOpsGetters; + private readonly int[] _inputOpsLengths; + + public TensorFlowSchema(IExceptionContext ectx, KeyValuePair[] columns, + MetadataUtils.MetadataGetter>[] opTypeGetters, + MetadataUtils.MetadataGetter>>[] inputOpsGetters, int[] inputOpsLengths) + : base(ectx, columns) + { + ectx.CheckParam(Utils.Size(opTypeGetters) == ColumnCount, nameof(opTypeGetters)); + ectx.CheckParam(Utils.Size(inputOpsGetters) == ColumnCount, nameof(inputOpsGetters)); + ectx.CheckParam(Utils.Size(inputOpsLengths) == ColumnCount, nameof(inputOpsLengths)); + + _opTypeGetters = opTypeGetters; + _inputOpsGetters = inputOpsGetters; + _inputOpsLengths = inputOpsLengths; + } + + protected override void GetMetadataCore(string kind, int col, ref TValue value) + { + Ectx.Assert(0 <= col && col < ColumnCount); + if (kind == OpType) + _opTypeGetters[col].Marshal(col, ref value); + else if (kind == InputOps && _inputOpsGetters[col] != null) + _inputOpsGetters[col].Marshal(col, ref value); + else + throw Ectx.ExceptGetMetadata(); + } + + protected override ColumnType GetMetadataTypeOrNullCore(string kind, int col) + { + Ectx.Assert(0 <= col && col < ColumnCount); + if (kind == OpType) + return TextType.Instance; + if (kind == InputOps && _inputOpsGetters[col] != null) + return new VectorType(TextType.Instance, _inputOpsLengths[col]); + return null; + } + + protected override IEnumerable> GetMetadataTypesCore(int col) + { + Ectx.Assert(0 <= col && col < ColumnCount); + yield return new KeyValuePair(OpType, TextType.Instance); + if (_inputOpsGetters[col] != null) + yield return new KeyValuePair(InputOps, new VectorType(TextType.Instance, _inputOpsLengths[col])); + } + } } } diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs index 578a9d5778..69532de8fc 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs @@ -122,6 +122,7 @@ private static TensorFlowTransform Create(IHostEnvironment env, ModelLoadContext byte[] modelBytes = null; if (!ctx.TryLoadBinaryStream("TFModel", r => modelBytes = r.ReadByteArray())) throw env.ExceptDecode(); + var session = TensorFlowUtils.LoadTFSession(env, modelBytes); var numInputs = ctx.Reader.ReadInt32(); env.CheckDecode(numInputs > 0); string[] inputs = new string[numInputs]; @@ -138,7 +139,7 @@ private static TensorFlowTransform Create(IHostEnvironment env, ModelLoadContext for (int j = 0; j < outputs.Length; j++) outputs[j] = ctx.LoadNonEmptyString(); - return new TensorFlowTransform(env, modelBytes, inputs, outputs); + return new TensorFlowTransform(env, session, inputs, outputs); } // Factory method for SignatureDataTransform. @@ -160,27 +161,12 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) => Create(env, ctx).MakeRowMapper(inputSchema); - private TFSession LoadTFSession(byte[] modelBytes) - { - var graph = new TFGraph(); - try - { - graph.Import(modelBytes, ""); - } - catch (Exception ex) - { -#pragma warning disable MSML_NoMessagesForLoadContext - throw _host.ExceptDecode(ex, "Tensorflow exception triggered while loading model."); -#pragma warning restore MSML_NoMessagesForLoadContext - } - return new TFSession(graph); - } - - private static byte[] CheckFileAndRead(IHostEnvironment env, string modelFile) + private static TFSession CheckFileAndRead(IHostEnvironment env, string modelFile) { env.CheckNonWhiteSpace(modelFile, nameof(modelFile)); env.CheckUserArg(File.Exists(modelFile), nameof(modelFile)); - return File.ReadAllBytes(modelFile); + var bytes = File.ReadAllBytes(modelFile); + return TensorFlowUtils.LoadTFSession(env, bytes, modelFile); } public TensorFlowTransform(IHostEnvironment env, string modelFile, string[] inputs, string[] outputs) : @@ -188,15 +174,14 @@ public TensorFlowTransform(IHostEnvironment env, string modelFile, string[] inpu { } - private TensorFlowTransform(IHostEnvironment env, byte[] modelBytes, string[] inputs, string[] outputs) + private TensorFlowTransform(IHostEnvironment env, TFSession session, string[] inputs, string[] outputs) { Contracts.CheckValue(env, nameof(env)); _host = env.Register(nameof(RegistrationName)); - _host.CheckValue(modelBytes, nameof(modelBytes)); + _host.CheckValue(session, nameof(session)); _host.CheckNonEmpty(inputs, nameof(inputs)); _host.CheckNonEmpty(outputs, nameof(outputs)); - - Session = LoadTFSession(modelBytes); + Session = session; foreach (var input in inputs) { _host.CheckNonWhiteSpace(input, nameof(inputs)); @@ -204,7 +189,7 @@ private TensorFlowTransform(IHostEnvironment env, byte[] modelBytes, string[] in throw _host.ExceptParam(nameof(inputs), $"Input column '{input}' does not exist in the model"); var tfInput = new TFOutput(Session.Graph[input]); if (!TensorFlowUtils.IsTypeSupported(tfInput.OutputType)) - throw _host.ExceptParam(nameof(modelBytes), $"Input type '{tfInput.OutputType}' of input column '{input}' is not supported in TensorFlow"); + throw _host.ExceptParam(nameof(session), $"Input type '{tfInput.OutputType}' of input column '{input}' is not supported in TensorFlow"); } var newNames = new HashSet(); diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index da9102fbb2..71f5f95f33 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -10,6 +10,7 @@ using Microsoft.ML.Runtime.LightGBM; using Microsoft.ML.Transforms; using Microsoft.ML.Transforms.TensorFlow; +using System; using System.Collections.Generic; using System.IO; using Xunit; @@ -181,6 +182,95 @@ public void TensorFlowTransformInceptionTest() } } + [Fact] + public void TensorFlowInputsOutputsSchemaTest() + { + using (var env = new ConsoleEnvironment(seed: 1, conc: 1)) + { + var model_location = "mnist_model/frozen_saved_model.pb"; + var schema = TensorFlowUtils.GetModelSchema(env, model_location); + Assert.Equal(54, schema.ColumnCount); + Assert.True(schema.TryGetColumnIndex("Placeholder", out int col)); + var type = schema.GetColumnType(col).AsVector; + Assert.Equal(2, type.DimCount); + Assert.Equal(28, type.GetDim(0)); + Assert.Equal(28, type.GetDim(1)); + var metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.OpType, col); + Assert.NotNull(metadataType); + Assert.True(metadataType.IsText); + ReadOnlyMemory opType = default; + schema.GetMetadata(TensorFlowUtils.OpType, col, ref opType); + Assert.Equal("Placeholder", opType.ToString()); + metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.InputOps, col); + Assert.Null(metadataType); + + Assert.True(schema.TryGetColumnIndex("conv2d/Conv2D/ReadVariableOp", out col)); + type = schema.GetColumnType(col).AsVector; + Assert.Equal(4, type.DimCount); + Assert.Equal(5, type.GetDim(0)); + Assert.Equal(5, type.GetDim(1)); + Assert.Equal(1, type.GetDim(2)); + Assert.Equal(32, type.GetDim(3)); + metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.OpType, col); + Assert.NotNull(metadataType); + Assert.True(metadataType.IsText); + schema.GetMetadata(TensorFlowUtils.OpType, col, ref opType); + Assert.Equal("Identity", opType.ToString()); + metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.InputOps, col); + Assert.NotNull(metadataType); + VBuffer> inputOps = default; + schema.GetMetadata(TensorFlowUtils.InputOps, col, ref inputOps); + Assert.Equal(1, inputOps.Length); + Assert.Equal("conv2d/kernel", inputOps.Values[0].ToString()); + + Assert.True(schema.TryGetColumnIndex("conv2d/Conv2D", out col)); + type = schema.GetColumnType(col).AsVector; + Assert.Equal(3, type.DimCount); + Assert.Equal(28, type.GetDim(0)); + Assert.Equal(28, type.GetDim(1)); + Assert.Equal(32, type.GetDim(2)); + metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.OpType, col); + Assert.NotNull(metadataType); + Assert.True(metadataType.IsText); + schema.GetMetadata(TensorFlowUtils.OpType, col, ref opType); + Assert.Equal("Conv2D", opType.ToString()); + metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.InputOps, col); + Assert.NotNull(metadataType); + schema.GetMetadata(TensorFlowUtils.InputOps, col, ref inputOps); + Assert.Equal(2, inputOps.Length); + Assert.Equal("reshape/Reshape", inputOps.Values[0].ToString()); + Assert.Equal("conv2d/Conv2D/ReadVariableOp", inputOps.Values[1].ToString()); + + Assert.True(schema.TryGetColumnIndex("Softmax", out col)); + type = schema.GetColumnType(col).AsVector; + Assert.Equal(1, type.DimCount); + Assert.Equal(10, type.GetDim(0)); + metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.OpType, col); + Assert.NotNull(metadataType); + Assert.True(metadataType.IsText); + schema.GetMetadata(TensorFlowUtils.OpType, col, ref opType); + Assert.Equal("Softmax", opType.ToString()); + metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.InputOps, col); + Assert.NotNull(metadataType); + schema.GetMetadata(TensorFlowUtils.InputOps, col, ref inputOps); + Assert.Equal(1, inputOps.Length); + Assert.Equal("sequential/dense_1/BiasAdd", inputOps.Values[0].ToString()); + + model_location = "model_matmul/frozen_saved_model.pb"; + schema = TensorFlowUtils.GetModelSchema(env, model_location); + char name = 'a'; + for (int i = 0; i < schema.ColumnCount; i++) + { + Assert.Equal(name.ToString(), schema.GetColumnName(i)); + type = schema.GetColumnType(i).AsVector; + Assert.Equal(2, type.DimCount); + Assert.Equal(2, type.GetDim(0)); + Assert.Equal(2, type.GetDim(1)); + name++; + } + } + } + [Fact] public void TensorFlowTransformMNISTConvTest() {