Skip to content

Commit

Permalink
Create API for extracting information about the nodes in a TensorFlow…
Browse files Browse the repository at this point in the history
… model (#862)

* Add a method that returns TensorFlow model outputs as an ISchema.

* Update after merge with master

* Address PR comments.

* Add metadata with information about the operation type, and the inputs needed for it.

* Add method that returns an enumerable of the information about graph nodes, and a console app that displays it

* Add the DnnAnalyzer project files.

* Address code review comments

* Make needed changes after merge with master

* Fix bug when there is a node with 1 dimension that is unknown
  • Loading branch information
yaeldMS authored Sep 20, 2018
1 parent 86f4d93 commit a627d5b
Show file tree
Hide file tree
Showing 8 changed files with 476 additions and 79 deletions.
11 changes: 11 additions & 0 deletions Microsoft.ML.sln
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Analyzer", "sr
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.StaticPipelineTesting", "test\Microsoft.ML.StaticPipelineTesting\Microsoft.ML.StaticPipelineTesting.csproj", "{8B38BF24-35F4-4787-A9C5-22D35987106E}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.ML.DnnAnalyzer", "src\Microsoft.ML.DnnAnalyzer\Microsoft.ML.DnnAnalyzer\Microsoft.ML.DnnAnalyzer.csproj", "{73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -419,6 +421,14 @@ Global
{8B38BF24-35F4-4787-A9C5-22D35987106E}.Release|Any CPU.Build.0 = Release|Any CPU
{8B38BF24-35F4-4787-A9C5-22D35987106E}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
{8B38BF24-35F4-4787-A9C5-22D35987106E}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
{73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Debug|Any CPU.Build.0 = Debug|Any CPU
{73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
{73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
{73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Release|Any CPU.ActiveCfg = Release|Any CPU
{73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Release|Any CPU.Build.0 = Release|Any CPU
{73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
{73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down Expand Up @@ -466,6 +476,7 @@ Global
{570A0B8A-5463-44D2-8521-54C0CA4CACA9} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{6DEF0F40-3853-47B3-8165-5F24BA5E14DF} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{8B38BF24-35F4-4787-A9C5-22D35987106E} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{73DAAC82-D308-48CC-8FFE-3B037F8BBCCA} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D}
Expand Down
126 changes: 82 additions & 44 deletions src/Microsoft.ML.Data/DataView/SimpleRow.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,97 +64,135 @@ public bool IsColumnActive(int col)
/// An <see cref="ISchema"/> that takes all column names and types as constructor parameters.
/// The columns do not have metadata.
/// </summary>
public sealed class SimpleSchema : ISchema
public abstract class SimpleSchemaBase : ISchema
{
private readonly IExceptionContext _ectx;
protected readonly IExceptionContext Ectx;
private readonly string[] _names;
private readonly ColumnType[] _types;
private readonly Dictionary<string, int> _columnNameMap;
private readonly MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>>[] _keyValueGetters;
protected readonly ColumnType[] Types;
protected readonly Dictionary<string, int> ColumnNameMap;

public int ColumnCount => _types.Length;
public int ColumnCount => Types.Length;

public SimpleSchema(IExceptionContext ectx, params KeyValuePair<string, ColumnType>[] columns)
protected SimpleSchemaBase(IExceptionContext ectx, params KeyValuePair<string, ColumnType>[] columns)
{
Contracts.CheckValueOrNull(ectx);
_ectx = ectx;
_ectx.CheckValue(columns, nameof(columns));
Ectx = ectx;
Ectx.CheckValue(columns, nameof(columns));

_names = new string[columns.Length];
_types = new ColumnType[columns.Length];
_columnNameMap = new Dictionary<string, int>();
Types = new ColumnType[columns.Length];
ColumnNameMap = new Dictionary<string, int>();
for (int i = 0; i < columns.Length; i++)
{
_names[i] = columns[i].Key;
_types[i] = columns[i].Value;
if (_columnNameMap.ContainsKey(columns[i].Key))
Types[i] = columns[i].Value;
if (ColumnNameMap.ContainsKey(columns[i].Key))
throw ectx.ExceptParam(nameof(columns), $"Duplicate column name: '{columns[i].Key}'");
_columnNameMap[columns[i].Key] = i;
}
_keyValueGetters = new MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>>[ColumnCount];
}

public SimpleSchema(IExceptionContext ectx, KeyValuePair<string, ColumnType>[] columns, Dictionary<string, MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>>> keyValues)
: this(ectx, columns)
{
foreach (var kvp in keyValues)
{
var name = kvp.Key;
var getter = kvp.Value;
if (!_columnNameMap.TryGetValue(name, out int col))
throw _ectx.ExceptParam(nameof(keyValues), $"Output schema does not contain column '{name}'");
if (!_types[col].ItemType.IsKey)
throw _ectx.ExceptParam(nameof(keyValues), $"Column '{name}' is not a key column, so it cannot have key value metadata");
_keyValueGetters[col] = getter;
ColumnNameMap[columns[i].Key] = i;
}
}

public bool TryGetColumnIndex(string name, out int col)
{
return _columnNameMap.TryGetValue(name, out col);
return ColumnNameMap.TryGetValue(name, out col);
}

public string GetColumnName(int col)
{
_ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
Ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
return _names[col];
}

public ColumnType GetColumnType(int col)
{
_ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
return _types[col];
Ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
return Types[col];
}

/// <summary>
/// Returns the metadata kinds and types available for the given column.
/// </summary>
/// <param name="col">Zero-based column index; must satisfy 0 &lt;= col &lt; <see cref="ColumnCount"/>.</param>
/// <returns>Whatever <see cref="GetMetadataTypesCore"/> produces for <paramref name="col"/>.</returns>
public IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypes(int col)
{
    // Validate the caller-supplied index with CheckParam, exactly as GetMetadataTypeOrNull and
    // GetMetadata do. An assert is not argument validation for a public entry point, and the
    // core implementation may be a lazy iterator (SimpleSchema's uses yield return), in which
    // case a bad index would otherwise only surface later, during enumeration.
    Ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
    return GetMetadataTypesCore(col);
}

protected abstract IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypesCore(int col);

/// <summary>
/// Looks up the type of one metadata kind on one column.
/// </summary>
/// <param name="kind">The metadata kind to look up.</param>
/// <param name="col">Zero-based column index; must satisfy 0 &lt;= col &lt; <see cref="ColumnCount"/>.</param>
/// <returns>The result of <see cref="GetMetadataTypeOrNullCore"/> for the validated column.</returns>
public ColumnType GetMetadataTypeOrNull(string kind, int col)
{
    // The index is validated here so that the core implementation only needs to assert it.
    bool columnInRange = col >= 0 && col < ColumnCount;
    Ectx.CheckParam(columnInRange, nameof(col));

    ColumnType metadataType = GetMetadataTypeOrNullCore(kind, col);
    return metadataType;
}

protected abstract ColumnType GetMetadataTypeOrNullCore(string kind, int col);

/// <summary>
/// Fetches the value of one metadata kind on one column into <paramref name="value"/>.
/// </summary>
/// <typeparam name="TValue">The type of the metadata value requested by the caller.</typeparam>
/// <param name="kind">The metadata kind to fetch.</param>
/// <param name="col">Zero-based column index; must satisfy 0 &lt;= col &lt; <see cref="ColumnCount"/>.</param>
/// <param name="value">Passed by reference to <see cref="GetMetadataCore{TValue}"/>, which fills it.</param>
public void GetMetadata<TValue>(string kind, int col, ref TValue value)
{
    // Validate once at the public boundary, then delegate to the derived class.
    bool columnInRange = col >= 0 && col < ColumnCount;
    Ectx.CheckParam(columnInRange, nameof(col));

    GetMetadataCore(kind, col, ref value);
}

protected abstract void GetMetadataCore<TValue>(string kind, int col, ref TValue value);
}

/// <summary>
/// An <see cref="ISchema"/> that takes all column names and types as constructor parameters.
/// The columns can optionally have text <see cref="MetadataUtils.Kinds.KeyValues"/> metadata.
/// </summary>
public sealed class SimpleSchema : SimpleSchemaBase
{
private readonly MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>>[] _keyValueGetters;

/// <summary>
/// Creates a schema from the given column names and types. No key-value getter is
/// registered for any column by this constructor.
/// </summary>
/// <param name="ectx">Exception context handed to the base class.</param>
/// <param name="columns">Name/type pairs, one per column, in column order.</param>
public SimpleSchema(IExceptionContext ectx, params KeyValuePair<string, ColumnType>[] columns)
    : base(ectx, columns)
{
    // One slot per column; a null slot means "this column has no key-value metadata".
    var emptyGetters = new MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>>[ColumnCount];
    _keyValueGetters = emptyGetters;
}

/// <summary>
/// Creates a schema from the given column names and types, and attaches key-value
/// metadata getters to the named key columns.
/// </summary>
/// <param name="ectx">Exception context handed to the base class.</param>
/// <param name="columns">Name/type pairs, one per column, in column order.</param>
/// <param name="keyValues">
/// Maps a column name to the getter that produces its key values. Must not be null; every
/// name must exist in <paramref name="columns"/> and refer to a column whose item type is a key.
/// </param>
public SimpleSchema(IExceptionContext ectx, KeyValuePair<string, ColumnType>[] columns,
    Dictionary<string, MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>>> keyValues)
    : this(ectx, columns)
{
    // Reject a null dictionary up front with a parameter error instead of letting the
    // foreach below fail with a NullReferenceException.
    Ectx.CheckValue(keyValues, nameof(keyValues));

    foreach (var kvp in keyValues)
    {
        var name = kvp.Key;
        var getter = kvp.Value;

        // The getter can only be attached to a column that exists...
        if (!ColumnNameMap.TryGetValue(name, out int col))
            throw Ectx.ExceptParam(nameof(keyValues), $"Output schema does not contain column '{name}'");
        // ...and whose item type is a key, since key-value metadata describes key columns only.
        if (!Types[col].ItemType.IsKey)
            throw Ectx.ExceptParam(nameof(keyValues), $"Column '{name}' is not a key column, so it cannot have key value metadata");

        _keyValueGetters[col] = getter;
    }
}

protected override IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypesCore(int col)
{
Ectx.Assert(0 <= col && col < ColumnCount);
if (_keyValueGetters[col] != null)
{
_ectx.Assert(_types[col].ItemType.IsKey);
Ectx.Assert(Types[col].ItemType.IsKey);
yield return new KeyValuePair<string, ColumnType>(MetadataUtils.Kinds.KeyValues,
new VectorType(TextType.Instance, _types[col].ItemType.KeyCount));
new VectorType(TextType.Instance, Types[col].ItemType.KeyCount));
}
}

public ColumnType GetMetadataTypeOrNull(string kind, int col)
protected override ColumnType GetMetadataTypeOrNullCore(string kind, int col)
{
_ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
Ectx.Assert(0 <= col && col < ColumnCount);
if (kind == MetadataUtils.Kinds.KeyValues && _keyValueGetters[col] != null)
{
_ectx.Assert(_types[col].ItemType.IsKey);
return new VectorType(TextType.Instance, _types[col].ItemType.KeyCount);
Ectx.Assert(Types[col].ItemType.IsKey);
return new VectorType(TextType.Instance, Types[col].ItemType.KeyCount);
}
return null;
}

public void GetMetadata<TValue>(string kind, int col, ref TValue value)
protected override void GetMetadataCore<TValue>(string kind, int col, ref TValue value)
{
_ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
Ectx.Assert(0 <= col && col < ColumnCount);
if (kind == MetadataUtils.Kinds.KeyValues && _keyValueGetters[col] != null)
_keyValueGetters[col].Marshal(col, ref value);
else
throw _ectx.ExceptGetMetadata();
throw Ectx.ExceptGetMetadata();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Transforms.TensorFlow;
using System;
using System.Linq;

namespace Microsoft.ML.DnnAnalyzer
{
/// <summary>
/// Console entry point that prints one line per node of a TensorFlow model: the node name,
/// its operation type, its output type and, when present, the names of its input nodes.
/// </summary>
public static class DnnAnalyzer
{
    /// <summary>
    /// Expects exactly one argument, the model location, which is passed unchanged to
    /// <c>TensorFlowUtils.GetModelNodes</c>.
    /// </summary>
    /// <param name="args">Command-line arguments; anything other than exactly one is a usage error.</param>
    public static void Main(string[] args)
    {
        if (Utils.Size(args) != 1)
        {
            Console.Error.WriteLine("Usage: dotnet DnnAnalyzer.dll <model_location>");
            // Main returns void, so without this the process would exit with code 0 and
            // scripts invoking the tool could not tell misuse from success.
            Environment.ExitCode = 1;
            return;
        }

        foreach (var (name, opType, type, inputs) in TensorFlowUtils.GetModelNodes(args[0]))
        {
            // Nodes without inputs get no "input nodes" suffix at all.
            var inputsString = inputs.Length == 0 ? "" : $", input nodes: {string.Join(", ", inputs)}";
            Console.WriteLine($"Graph node: '{name}', operation type: '{opType}', output type: '{type}'{inputsString}");
        }
    }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
<Project Sdk="Microsoft.NET.Sdk">

<!-- Console executable named DnnAnalyzer, targeting .NET Core 2.1. -->
<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>netcoreapp2.1</TargetFramework>
<AssemblyName>DnnAnalyzer</AssemblyName>
<!-- NOTE(review): IncludeInPackage is not a standard SDK property; presumably the repository's
     packaging targets use it to ship this tool with the Microsoft.ML.TensorFlow package. Confirm. -->
<IncludeInPackage>Microsoft.ML.TensorFlow</IncludeInPackage>
</PropertyGroup>

<!-- Project dependencies: the core library and the TensorFlow integration. -->
<ItemGroup>
<ProjectReference Include="..\..\Microsoft.ML.Core\Microsoft.ML.Core.csproj" />
<ProjectReference Include="..\..\Microsoft.ML.TensorFlow\Microsoft.ML.TensorFlow.csproj" />
</ItemGroup>

<!-- NOTE(review): NativeAssemblyReference is a custom item type; presumably the repository's build
     targets use it to place the native tensorflow library next to the output. Confirm. -->
<ItemGroup>
<NativeAssemblyReference Include="tensorflow" />
</ItemGroup>

</Project>
81 changes: 71 additions & 10 deletions src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

using size_t = System.UIntPtr;
using System.Collections.Generic;
using System.Collections;

#pragma warning disable MSML_GeneralName
#pragma warning disable MSML_PrivateFieldName
Expand Down Expand Up @@ -492,7 +493,7 @@ public void SetConfig(IntPtr protoData, int length, TFStatus status = null)
/// "hot", and add a "sub" operation there the result will be "demo/hot/sub".
/// </para>
/// </remarks>
internal partial class TFGraph : TFDisposableThreadSafe
internal partial class TFGraph : TFDisposableThreadSafe, IEnumerable<TFOperation>
{
// extern TF_Graph * TF_NewGraph ();
[DllImport(NativeBinding.TensorFlowLibrary)]
Expand Down Expand Up @@ -696,6 +697,33 @@ public override string ToString()
IntPtr len;
return TF_GraphDebugString(Handle, out len);
}

[DllImport(NativeBinding.TensorFlowLibrary)]
private static unsafe extern TF_Operation TF_GraphNextOperation(TF_Graph graph, ref IntPtr pos);

/// <summary>
/// Lazily produces a <see cref="TFOperation"/> for every operation handle that the native
/// TF_GraphNextOperation returns for this graph.
/// </summary>
/// <returns>A sequence of the graph's operations, wrapped as <see cref="TFOperation"/> instances.</returns>
private IEnumerable<TFOperation> GetEnumerable()
{
    // This is an iterator, so the disposed check runs on the first MoveNext, not at call time.
    if (handle == IntPtr.Zero)
        ObjectDisposedException();

    // Iteration position: starts at zero and is handed to the native call by reference on every step.
    IntPtr position = IntPtr.Zero;
    while (true)
    {
        IntPtr nativeOperation = TF_GraphNextOperation(handle, ref position);
        // A null handle marks the end of the iteration.
        if (nativeOperation == IntPtr.Zero)
            yield break;

        yield return new TFOperation(this, nativeOperation);
    }
}

/// <summary>
/// Returns an enumerator over all the operations in this graph.
/// </summary>
/// <returns>The enumerator obtained from <see cref="GetEnumerable"/>.</returns>
public IEnumerator<TFOperation> GetEnumerator() => GetEnumerable().GetEnumerator();

/// <summary>
/// Non-generic enumeration; forwards to the generic <see cref="GetEnumerator"/>.
/// </summary>
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}

/// <summary>
Expand Down Expand Up @@ -736,6 +764,48 @@ public TFOutput this[int idx]
return new TFOutput(this, idx);
}
}

// extern TF_Output TF_OperationInput (TF_Input oper_in);
[DllImport(NativeBinding.TensorFlowLibrary)]
private static extern TFOutput TF_OperationInput(TFInput oper_in);

/// <summary>
/// Returns what the native TF_OperationInput reports for input slot <paramref name="idx"/>
/// of this operation.
/// </summary>
/// <param name="idx">Index of the input on this operation. It is passed through unchecked.</param>
/// <returns>The <see cref="TFOutput"/> returned by the native call.</returns>
public TFOutput GetInput(int idx)
{
    // Describe the input slot as the (operation handle, index) pair the native API expects.
    var inputSlot = new TFInput();
    inputSlot.Operation = handle;
    inputSlot.Index = idx;

    return TF_OperationInput(inputSlot);
}

[DllImport(NativeBinding.TensorFlowLibrary)]
private static extern IntPtr TF_OperationName(TF_Operation oper);

/// <summary>
/// The name of this operation, or a placeholder string when the native handle is null.
/// </summary>
/// <value>The name.</value>
public string Name
{
    get
    {
        // A null handle yields a placeholder instead of a call into native code.
        if (handle == IntPtr.Zero)
            return "<ObjectDisposed>";

        return TF_OperationName(handle).GetStr();
    }
}

[DllImport(NativeBinding.TensorFlowLibrary)]
private static extern IntPtr TF_OperationOpType(TF_Operation oper);

/// <summary>
/// The operation type of this operation as reported by the native TF_OperationOpType,
/// or a placeholder string when the native handle is null.
/// </summary>
/// <value>The operation type.</value>
public string OpType
{
    get
    {
        // A null handle yields a placeholder instead of a call into native code.
        if (handle == IntPtr.Zero)
            return "<ObjectDisposedException>";

        return TF_OperationOpType(handle).GetStr();
    }
}

[DllImport(NativeBinding.TensorFlowLibrary)]
private static extern int TF_OperationNumOutputs(TF_Operation oper);

/// <summary>
/// Gets the number of outputs of this operation.
/// </summary>
/// <value>The number of outputs, or -1 when the native handle is null.</value>
public int NumOutputs
{
    get
    {
        // -1 is the sentinel for a null handle; native code is not called in that case.
        if (handle == IntPtr.Zero)
            return -1;

        return TF_OperationNumOutputs(handle);
    }
}

[DllImport(NativeBinding.TensorFlowLibrary)]
private static extern int TF_OperationNumInputs(TF_Operation oper);

/// <summary>
/// Gets the number of inputs for this operation.
/// </summary>
/// <value>The number of inputs, or -1 when the native handle is null.</value>
// Guard against a null handle the same way NumOutputs does, so that a null pointer is never
// passed to native code.
public int NumInputs => handle == IntPtr.Zero ? -1 : TF_OperationNumInputs(handle);
}

/// <summary>
Expand Down Expand Up @@ -1768,15 +1838,6 @@ internal struct TFInput
/// </summary>
public int Index;

// extern TF_Output TF_OperationInput (TF_Input oper_in);
[DllImport(NativeBinding.TensorFlowLibrary)]
private static extern TFOutput TF_OperationInput(TFInput oper_in);

public TFOutput GetOutput(TFInput operIn)
{
return TF_OperationInput(operIn);
}

// extern TF_DataType TF_OperationInputType (TF_Input oper_in);
[DllImport(NativeBinding.TensorFlowLibrary)]
private static extern TFDataType TF_OperationInputType(TFInput oper_in);
Expand Down
Loading

0 comments on commit a627d5b

Please sign in to comment.