Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create API for extracting information about the nodes in a TensorFlow model #862

Merged
merged 12 commits into from
Sep 20, 2018
11 changes: 11 additions & 0 deletions Microsoft.ML.sln
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Analyzer", "sr
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.StaticPipelineTesting", "test\Microsoft.ML.StaticPipelineTesting\Microsoft.ML.StaticPipelineTesting.csproj", "{8B38BF24-35F4-4787-A9C5-22D35987106E}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.ML.DnnAnalyzer", "src\Microsoft.ML.DnnAnalyzer\Microsoft.ML.DnnAnalyzer\Microsoft.ML.DnnAnalyzer.csproj", "{73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -419,6 +421,14 @@ Global
{8B38BF24-35F4-4787-A9C5-22D35987106E}.Release|Any CPU.Build.0 = Release|Any CPU
{8B38BF24-35F4-4787-A9C5-22D35987106E}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
{8B38BF24-35F4-4787-A9C5-22D35987106E}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
{73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Debug|Any CPU.Build.0 = Debug|Any CPU
{73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
{73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
{73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Release|Any CPU.ActiveCfg = Release|Any CPU
{73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Release|Any CPU.Build.0 = Release|Any CPU
{73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
{73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down Expand Up @@ -466,6 +476,7 @@ Global
{570A0B8A-5463-44D2-8521-54C0CA4CACA9} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{6DEF0F40-3853-47B3-8165-5F24BA5E14DF} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{8B38BF24-35F4-4787-A9C5-22D35987106E} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{73DAAC82-D308-48CC-8FFE-3B037F8BBCCA} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D}
Expand Down
126 changes: 82 additions & 44 deletions src/Microsoft.ML.Data/DataView/SimpleRow.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,97 +64,135 @@ public bool IsColumnActive(int col)
/// An <see cref="ISchema"/> that takes all column names and types as constructor parameters.
/// The columns do not have metadata.
/// </summary>
public sealed class SimpleSchema : ISchema
public abstract class SimpleSchemaBase : ISchema
{
private readonly IExceptionContext _ectx;
protected readonly IExceptionContext Ectx;
private readonly string[] _names;
private readonly ColumnType[] _types;
private readonly Dictionary<string, int> _columnNameMap;
private readonly MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>>[] _keyValueGetters;
protected readonly ColumnType[] Types;
protected readonly Dictionary<string, int> ColumnNameMap;

public int ColumnCount => _types.Length;
public int ColumnCount => Types.Length;

public SimpleSchema(IExceptionContext ectx, params KeyValuePair<string, ColumnType>[] columns)
protected SimpleSchemaBase(IExceptionContext ectx, params KeyValuePair<string, ColumnType>[] columns)
{
Contracts.CheckValueOrNull(ectx);
_ectx = ectx;
_ectx.CheckValue(columns, nameof(columns));
Ectx = ectx;
Ectx.CheckValue(columns, nameof(columns));

_names = new string[columns.Length];
_types = new ColumnType[columns.Length];
_columnNameMap = new Dictionary<string, int>();
Types = new ColumnType[columns.Length];
ColumnNameMap = new Dictionary<string, int>();
for (int i = 0; i < columns.Length; i++)
{
_names[i] = columns[i].Key;
_types[i] = columns[i].Value;
if (_columnNameMap.ContainsKey(columns[i].Key))
Types[i] = columns[i].Value;
if (ColumnNameMap.ContainsKey(columns[i].Key))
throw ectx.ExceptParam(nameof(columns), $"Duplicate column name: '{columns[i].Key}'");
_columnNameMap[columns[i].Key] = i;
}
_keyValueGetters = new MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>>[ColumnCount];
}

public SimpleSchema(IExceptionContext ectx, KeyValuePair<string, ColumnType>[] columns, Dictionary<string, MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>>> keyValues)
: this(ectx, columns)
{
foreach (var kvp in keyValues)
{
var name = kvp.Key;
var getter = kvp.Value;
if (!_columnNameMap.TryGetValue(name, out int col))
throw _ectx.ExceptParam(nameof(keyValues), $"Output schema does not contain column '{name}'");
if (!_types[col].ItemType.IsKey)
throw _ectx.ExceptParam(nameof(keyValues), $"Column '{name}' is not a key column, so it cannot have key value metadata");
_keyValueGetters[col] = getter;
ColumnNameMap[columns[i].Key] = i;
}
}

public bool TryGetColumnIndex(string name, out int col)
{
return _columnNameMap.TryGetValue(name, out col);
return ColumnNameMap.TryGetValue(name, out col);
}

public string GetColumnName(int col)
{
_ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
Ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
return _names[col];
}

public ColumnType GetColumnType(int col)
{
_ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
return _types[col];
Ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
return Types[col];
}

public IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypes(int col)
{
_ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
Ectx.Assert(0 <= col && col < ColumnCount);
return GetMetadataTypesCore(col);
}

protected abstract IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypesCore(int col);

/// <summary>
/// Validates the column index and then returns whatever
/// <see cref="GetMetadataTypeOrNullCore"/> reports for the given metadata kind and column.
/// </summary>
/// <param name="kind">The metadata kind; passed through unchanged to the core method.</param>
/// <param name="col">The column index; must satisfy 0 &lt;= col &lt; <see cref="ColumnCount"/>.</param>
/// <returns>The result of <see cref="GetMetadataTypeOrNullCore"/> for (kind, col).</returns>
public ColumnType GetMetadataTypeOrNull(string kind, int col)
{
// The range check happens here, before delegating, so the core method is only reached with an in-range index.
Ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
return GetMetadataTypeOrNullCore(kind, col);
}

protected abstract ColumnType GetMetadataTypeOrNullCore(string kind, int col);

/// <summary>
/// Validates the column index and then asks <see cref="GetMetadataCore"/> to fill
/// <paramref name="value"/> with the metadata of the given kind for that column.
/// </summary>
/// <typeparam name="TValue">The type of the metadata value being requested.</typeparam>
/// <param name="kind">The metadata kind; passed through unchanged to the core method.</param>
/// <param name="col">The column index; must satisfy 0 &lt;= col &lt; <see cref="ColumnCount"/>.</param>
/// <param name="value">Passed by reference to the core method, which is responsible for setting it.</param>
public void GetMetadata<TValue>(string kind, int col, ref TValue value)
{
// The range check happens here, before delegating, so the core method is only reached with an in-range index.
Ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
GetMetadataCore(kind, col, ref value);
}

protected abstract void GetMetadataCore<TValue>(string kind, int col, ref TValue value);
}

/// <summary>
/// An <see cref="ISchema"/> that takes all column names and types as constructor parameters.
/// The columns can optionally have text <see cref="MetadataUtils.Kinds.KeyValues"/> metadata.
/// </summary>
public sealed class SimpleSchema : SimpleSchemaBase
{
private readonly MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>>[] _keyValueGetters;

/// <summary>
/// Creates a schema from the given column names and types with no key value getters registered.
/// </summary>
/// <param name="ectx">Exception context forwarded to the base constructor.</param>
/// <param name="columns">Name/type pairs, one per column, forwarded to the base constructor.</param>
public SimpleSchema(IExceptionContext ectx, params KeyValuePair<string, ColumnType>[] columns)
: base(ectx, columns)
{
// One slot per column. Slots are left at their default (null) value here; the metadata
// methods in this class test a slot against null to decide whether key values exist.
_keyValueGetters = new MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>>[ColumnCount];
}

/// <summary>
/// Creates a schema from the given column names and types and registers a key value
/// metadata getter for each column named in <paramref name="keyValues"/>.
/// </summary>
/// <param name="ectx">Exception context forwarded to the other constructor.</param>
/// <param name="columns">Name/type pairs, one per column.</param>
/// <param name="keyValues">
/// Maps a column name to the getter that produces its key value metadata. Every name must
/// exist in <paramref name="columns"/> and refer to a column whose item type is a key type.
/// </param>
/// <remarks>
/// A null <paramref name="keyValues"/>, an unknown column name, or a non-key column is
/// reported through <c>Ectx</c> as a parameter error on <paramref name="keyValues"/>.
/// </remarks>
public SimpleSchema(IExceptionContext ectx, KeyValuePair<string, ColumnType>[] columns,
    Dictionary<string, MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>>> keyValues)
    : this(ectx, columns)
{
    // Validate up front, the same way the base constructor validates 'columns', so a null
    // dictionary is reported as a bad argument rather than a NullReferenceException below.
    Ectx.CheckValue(keyValues, nameof(keyValues));

    foreach (var kvp in keyValues)
    {
        var name = kvp.Key;
        var getter = kvp.Value;
        if (!ColumnNameMap.TryGetValue(name, out int col))
            throw Ectx.ExceptParam(nameof(keyValues), $"Output schema does not contain column '{name}'");
        if (!Types[col].ItemType.IsKey)
            throw Ectx.ExceptParam(nameof(keyValues), $"Column '{name}' is not a key column, so it cannot have key value metadata");
        _keyValueGetters[col] = getter;
    }
}

protected override IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypesCore(int col)
{
Ectx.Assert(0 <= col && col < ColumnCount);
if (_keyValueGetters[col] != null)
{
_ectx.Assert(_types[col].ItemType.IsKey);
Ectx.Assert(Types[col].ItemType.IsKey);
yield return new KeyValuePair<string, ColumnType>(MetadataUtils.Kinds.KeyValues,
new VectorType(TextType.Instance, _types[col].ItemType.KeyCount));
new VectorType(TextType.Instance, Types[col].ItemType.KeyCount));
}
}

public ColumnType GetMetadataTypeOrNull(string kind, int col)
protected override ColumnType GetMetadataTypeOrNullCore(string kind, int col)
{
_ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
Ectx.Assert(0 <= col && col < ColumnCount);
if (kind == MetadataUtils.Kinds.KeyValues && _keyValueGetters[col] != null)
{
_ectx.Assert(_types[col].ItemType.IsKey);
return new VectorType(TextType.Instance, _types[col].ItemType.KeyCount);
Ectx.Assert(Types[col].ItemType.IsKey);
return new VectorType(TextType.Instance, Types[col].ItemType.KeyCount);
}
return null;
}

public void GetMetadata<TValue>(string kind, int col, ref TValue value)
protected override void GetMetadataCore<TValue>(string kind, int col, ref TValue value)
{
_ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
Ectx.Assert(0 <= col && col < ColumnCount);
if (kind == MetadataUtils.Kinds.KeyValues && _keyValueGetters[col] != null)
_keyValueGetters[col].Marshal(col, ref value);
else
throw _ectx.ExceptGetMetadata();
throw Ectx.ExceptGetMetadata();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Transforms.TensorFlow;
using System;
using System.Linq;

namespace Microsoft.ML.DnnAnalyzer
{
/// <summary>
/// Console entry point that prints one line per node of a TensorFlow model.
/// </summary>
public static class DnnAnalyzer
{
    /// <summary>
    /// Prints the nodes of the model located at <c>args[0]</c>. When <c>Utils.Size(args)</c>
    /// is not 1, writes a usage message to standard error and does nothing else.
    /// </summary>
    /// <param name="args">Command-line arguments; the single expected argument is the model location.</param>
    public static void Main(string[] args)
    {
        if (Utils.Size(args) == 1)
        {
            PrintNodes(args[0]);
            return;
        }

        Console.Error.WriteLine("Usage: dotnet DnnAnalyzer.dll <model_location>");
    }

    // Writes one description line to standard output for every node that
    // TensorFlowUtils.GetModelNodes yields for the given model location.
    private static void PrintNodes(string modelLocation)
    {
        foreach (var (name, opType, type, inputs) in TensorFlowUtils.GetModelNodes(modelLocation))
        {
            var description = $"Graph node: '{name}', operation type: '{opType}', output type: '{type}'";

            // The input list is appended only when the node actually has inputs.
            if (inputs.Length != 0)
                description += $", input nodes: {string.Join(", ", inputs)}";

            Console.WriteLine(description);
        }
    }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>netcoreapp2.1</TargetFramework>
<AssemblyName>DnnAnalyzer</AssemblyName>
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Sep 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should it be part of TensorFlow package? #Closed

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should, thanks for reminding.


In reply to: 218189010 [](ancestors = 218189010)

<IncludeInPackage>Microsoft.ML.TensorFlow</IncludeInPackage>
</PropertyGroup>

<ItemGroup>
<ProjectReference Include="..\..\Microsoft.ML.Core\Microsoft.ML.Core.csproj" />
<ProjectReference Include="..\..\Microsoft.ML.TensorFlow\Microsoft.ML.TensorFlow.csproj" />
</ItemGroup>

<ItemGroup>
<NativeAssemblyReference Include="tensorflow" />
</ItemGroup>

</Project>
81 changes: 71 additions & 10 deletions src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

using size_t = System.UIntPtr;
using System.Collections.Generic;
using System.Collections;

#pragma warning disable MSML_GeneralName
#pragma warning disable MSML_PrivateFieldName
Expand Down Expand Up @@ -492,7 +493,7 @@ public void SetConfig(IntPtr protoData, int length, TFStatus status = null)
/// "hot", and add a "sub" operation there the result will be "demo/hot/sub".
/// </para>
/// </remarks>
internal partial class TFGraph : TFDisposableThreadSafe
internal partial class TFGraph : TFDisposableThreadSafe, IEnumerable<TFOperation>
{
// extern TF_Graph * TF_NewGraph ();
[DllImport(NativeBinding.TensorFlowLibrary)]
Expand Down Expand Up @@ -696,6 +697,33 @@ public override string ToString()
IntPtr len;
return TF_GraphDebugString(Handle, out len);
}

[DllImport(NativeBinding.TensorFlowLibrary)]
private static unsafe extern TF_Operation TF_GraphNextOperation(TF_Graph graph, ref IntPtr pos);

/// <summary>
/// Returns a sequence that yields a <see cref="TFOperation"/> for each operation handle
/// produced by the native <c>TF_GraphNextOperation</c> call on this graph.
/// </summary>
/// <remarks>
/// This is an iterator method, so nothing here (including the zero-handle check) runs
/// until the returned sequence is enumerated.
/// </remarks>
/// <returns>The lazily evaluated sequence of operations.</returns>
private IEnumerable<TFOperation> GetEnumerable()
{
    if (handle == IntPtr.Zero)
        ObjectDisposedException();

    // The native call advances this position token; iteration starts from a zero token.
    IntPtr position = IntPtr.Zero;
    while (true)
    {
        IntPtr nativeOperation = TF_GraphNextOperation(handle, ref position);

        // A zero handle from the native side ends the iteration.
        if (nativeOperation == IntPtr.Zero)
            yield break;

        yield return new TFOperation(this, nativeOperation);
    }
}

/// <summary>
/// Returns an enumerator over the sequence produced by <see cref="GetEnumerable"/>.
/// </summary>
/// <returns>A new enumerator over this graph's operations.</returns>
public IEnumerator<TFOperation> GetEnumerator() => GetEnumerable().GetEnumerator();

/// <summary>
/// Non-generic enumeration; forwards to the generic <see cref="GetEnumerator"/>.
/// </summary>
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}

/// <summary>
Expand Down Expand Up @@ -736,6 +764,48 @@ public TFOutput this[int idx]
return new TFOutput(this, idx);
}
}

// extern TF_Output TF_OperationInput (TF_Input oper_in);
[DllImport(NativeBinding.TensorFlowLibrary)]
private static extern TFOutput TF_OperationInput(TFInput oper_in);

/// <summary>
/// Calls the native <c>TF_OperationInput</c> for the input at position
/// <paramref name="idx"/> of this operation and returns its result.
/// </summary>
/// <param name="idx">Index of the input on this operation; it is not range-checked here.</param>
/// <returns>The <see cref="TFOutput"/> returned by the native call.</returns>
public TFOutput GetInput(int idx)
{
    // Describe "input number idx of this operation" in the shape the native API expects.
    var inputSlot = new TFInput { Operation = handle, Index = idx };
    return TF_OperationInput(inputSlot);
}

[DllImport(NativeBinding.TensorFlowLibrary)]
private static extern IntPtr TF_OperationName(TF_Operation oper);

/// <summary>
/// The name of this operation as reported by the native <c>TF_OperationName</c> call,
/// or the placeholder "&lt;ObjectDisposed&gt;" when the native handle is zero.
/// </summary>
/// <value>The name.</value>
public string Name => handle == IntPtr.Zero ? "<ObjectDisposed>" : TF_OperationName(handle).GetStr();

[DllImport(NativeBinding.TensorFlowLibrary)]
private static extern IntPtr TF_OperationOpType(TF_Operation oper);

/// <summary>
/// The op type string of this operation as reported by the native <c>TF_OperationOpType</c>
/// call, or the placeholder "&lt;ObjectDisposedException&gt;" when the native handle is zero.
/// </summary>
/// <value>The operation type.</value>
public string OpType => handle == IntPtr.Zero ? "<ObjectDisposedException>" : TF_OperationOpType(handle).GetStr();

[DllImport(NativeBinding.TensorFlowLibrary)]
private static extern int TF_OperationNumOutputs(TF_Operation oper);

/// <summary>
/// Gets the number of outputs on this operation, or -1 when the native handle is zero.
/// </summary>
/// <value>The number of outputs.</value>
public int NumOutputs => handle == IntPtr.Zero ? -1 : TF_OperationNumOutputs(handle);

[DllImport(NativeBinding.TensorFlowLibrary)]
private static extern int TF_OperationNumInputs(TF_Operation oper);

/// <summary>
/// Gets the number of inputs for this operation, or -1 when the native handle is zero.
/// </summary>
/// <value>The number of inputs.</value>
// Guard the handle the same way NumOutputs does, so a zero handle is never passed to native code.
public int NumInputs => handle == IntPtr.Zero ? -1 : TF_OperationNumInputs(handle);
}

/// <summary>
Expand Down Expand Up @@ -1768,15 +1838,6 @@ internal struct TFInput
/// </summary>
public int Index;

// extern TF_Output TF_OperationInput (TF_Input oper_in);
[DllImport(NativeBinding.TensorFlowLibrary)]
private static extern TFOutput TF_OperationInput(TFInput oper_in);

public TFOutput GetOutput(TFInput operIn)
{
return TF_OperationInput(operIn);
}

// extern TF_DataType TF_OperationInputType (TF_Input oper_in);
[DllImport(NativeBinding.TensorFlowLibrary)]
private static extern TFDataType TF_OperationInputType(TFInput oper_in);
Expand Down
Loading