Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create API for extracting information about the nodes in a TensorFlow model #862

Merged
merged 12 commits into from
Sep 20, 2018

Conversation

yaeldekel
Copy link

This PR addresses issue #791 .
Please feel free to add feedback or suggestions.

using System.Collections.Generic;
using System.Linq.Expressions;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Data;
Copy link
Member

@ericstj ericstj Sep 10, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these using statements actually necessary? I'm missing the additions that actually used them. #Resolved

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing it out. I added the new API in this file, but ended up moving it to TensorflowUtils.


In reply to: 216493797 [](ancestors = 216493797)

@@ -700,6 +698,24 @@ public override string ToString()
IntPtr len;
return TF_GraphDebugString(Handle, out len);
}

[DllImport(NativeBinding.TensorFlowLibrary)]
internal static extern string TF_OperationName(TF_Operation oper);
Copy link
Member

@ericstj ericstj Sep 10, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than directly exposing the TF C-API we should consider bringing back the TF# defined API (TFOperation class) on top of it. See zeahmed@b2a8016#diff-ec7ea5716f3c05f773d3e1507b4f486aL729, zeahmed@b2a8016#diff-ec7ea5716f3c05f773d3e1507b4f486aL748. #Resolved

continue;

var numInputs = TFGraph.TF_OperationNumInputs(oper);
if (numInputs == 0)
Copy link
Contributor

@zeahmed zeahmed Sep 10, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

numInputs [](start = 20, length = 9)

What does numInputs == 0 indicates? The input node does not have any input I believe??? #Closed

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The input has OpType "Placeholder". There are other nodes with numInputs==0, which have OpType "Const". I am not sure what they do but I think we don't want them in our output schema. What do you think?


In reply to: 216508337 [](ancestors = 216508337)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you think if such nodes are not filter? Do you foresee adverse effect of doing this or not doing this?


In reply to: 216758114 [](ancestors = 216758114,216508337)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added the nodes with 0 inputs to the schema as well.


In reply to: 216817673 [](ancestors = 216817673,216758114,216508337)

while ((oper = TFGraph.TF_GraphNextOperation(graph.handle, &pos)) != IntPtr.Zero)
{
var name = TFGraph.TF_OperationName(oper);
var type = TFGraph.TF_OperationOpType(oper);
Copy link
Contributor

@zeahmed zeahmed Sep 10, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

type [](start = 20, length = 4)

Its not being used anywhere. #Closed

var model_location = "mnist_model/frozen_saved_model.pb";
using (var env = new TlcEnvironment(seed: 1, conc: 1))
{
var schema = TensorFlowUtils.GetModelSchema(env, model_location);
Copy link
Contributor

@zeahmed zeahmed Sep 10, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

schema [](start = 20, length = 6)

It would be nice to have schema checked against the actual model information. It seems like mnist_model/frozen_saved_model.pb is a big model. Matrix multiplication model used in above test may be a good for this test.
It can be implemented as a new test in addition to this one. #Closed

public static ISchema GetModelSchema(IExceptionContext ectx, string modelFile)
{
var bytes = File.ReadAllBytes(modelFile);
var session = LoadTFSession(ectx, bytes, modelFile);
Copy link
Contributor

@zeahmed zeahmed Sep 10, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LoadTFSession [](start = 26, length = 13)

What about models that are not frozen? I assume it will be there once @abgoswam changes are there, right? #Closed

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. I assume that LoadTFSession will take care of the logic to decide which kind of model to load.


In reply to: 216509938 [](ancestors = 216509938)

{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(RegistrationName));
_host.CheckValue(modelBytes, nameof(modelBytes));
Session = LoadTFSession(modelBytes);
Session = TensorFlowUtils.LoadTFSession(_host, modelBytes, modelFile);
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Sep 11, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Abhishek is working on brining another way to load session, so I think it's better to extend CheckFileAndRead function and force it to return you Session instead of byte array. So private constructor would just accept session. #Closed

@@ -182,16 +166,16 @@ private static byte[] CheckFileAndRead(IHostEnvironment env, string modelFile)
}

public TensorFlowTransform(IHostEnvironment env, string modelFile, string[] inputs, string[] outputs) :
this(env, CheckFileAndRead(env, modelFile), inputs, outputs)
this(env, CheckFileAndRead(env, modelFile), inputs, outputs, modelFile)
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Sep 11, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

modelFile [](start = 73, length = 9)

is it model file or model args? #Closed

Copy link
Contributor

@zeahmed zeahmed left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:shipit:

@Ivanidzo4ka
Copy link
Contributor

Ivanidzo4ka commented Sep 13, 2018

Is it still WIP or it's ok to review it properly? I see people signing off, and WIP status and find this a bit confusing #Resolved

@yaeldekel yaeldekel changed the title WIP: Create API for extracting information about the nodes in a TensorFlow model Create API for extracting information about the nodes in a TensorFlow model Sep 14, 2018
return;
}

foreach (var (name, opType, type, inputs) in TensorFlowUtils.GetModelNodes(args[0]))
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

foreach [](start = 16, length = 7)

I can add more arguments to this app, to let the user filter nodes by operation type (for example, sometimes there are lots of "Const" nodes that users might not be interested in if they are just trying to find the name of a certain layer).

Is this valuable? Is having a method that returns this information enough so users can filter programatically, or would user want this as well?

@yaeldekel
Copy link
Author

Please review properly, I removed the WIP from the title.


In reply to: 421182387 [](ancestors = 421182387)

<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>netcoreapp2.1</TargetFramework>
<AssemblyName>DnnAnalyzer</AssemblyName>
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Sep 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should it be part of TensorFlow package? #Closed

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should, thanks for reminding.


In reply to: 218189010 [](ancestors = 218189010)

{
if (Utils.Size(args) != 1)
{
ch.Error("Usage: dotnet DnnAnalyzer.dll <model_location>");
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Sep 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.dll [](start = 55, length = 4)

is dll necessary? and should it be dotnet run or just dotnet works? #Closed

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"dotnet DnnAnalyzer " didn't work.
"dotnet run DnnAnalyzer.dll " didn't work.

Is there a different syntax I should try?


In reply to: 218189666 [](ancestors = 218189666)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, apparently i need to study dotnet better. dotnet run for projects only,


In reply to: 218535326 [](ancestors = 218535326,218189666)

{
public static void Main(string[] args)
{
using (var env = new TlcEnvironment())
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Sep 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TlcEnvironment [](start = 33, length = 14)

isn't it feel weird to create environment only to write something to console?
Why can't you just do Console.Writeline? #Resolved

{
}

private TensorFlowTransform(IHostEnvironment env, byte[] modelBytes, string[] inputs, string[] outputs)
private TensorFlowTransform(IHostEnvironment env, TFSession session, string[] inputs, string[] outputs, string modelFile = null)
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Sep 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

string modelFile = null [](start = 112, length = 23)

not necessary anymore #Closed

var opTypeGetters = new List<MetadataUtils.MetadataGetter<DvText>>();
var inputOpsGetters = new List<MetadataUtils.MetadataGetter<VBuffer<DvText>>>();
var inputOpsLengths = new List<int>();
foreach (var oper in graph)
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Sep 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oper [](start = 25, length = 4)

nit: If you shorten operation to oper, you can just go even further to "op" option.
#Closed

var inputOpsLengths = new List<int>();
foreach (var oper in graph)
{
if (oper.NumOutputs != 1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if (oper.NumOutputs != 1) [](start = 16, length = 25)

I think this deserve comment.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed this code. I think that if a node has multiple outputs it means that the output it produces is used as input to multiple nodes, but the shape and type of this output will be the same for every node that uses it as input. In this case there is no need to skip it. @zeahmed, is my assumption correct?


In reply to: 218204247 [](ancestors = 218204247)

Copy link
Contributor

@zeahmed zeahmed Sep 19, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, tf.Operation can produced multiple outputs with different types and shapes. In case of single output used by multiple nodes, I am not sure if `oper.NumOutputs > 1' in that case. Technically, recurrent layers can produced two outputs (hidden state, output) but I would need to see how the recurrent ops are implemented in graphs.


In reply to: 218551502 [](ancestors = 218551502,218204247)

@@ -25,7 +30,92 @@ public static void Initialize()
ImageAnalytics.Initialize();
}

private static unsafe ISchema GetModelSchema(IExceptionContext ectx, TFGraph graph)
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Sep 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unsafe [](start = 23, length = 6)

just curious, why it's unsafe? #Closed

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing it out, it doesn't need to be. In a previous iteration I was directly running the while loop over the graph that is now in the Graph class, so it was unsafe.


In reply to: 218204434 [](ancestors = 218204434)

Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🕐

Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:shipit:

@yaeldekel yaeldekel merged commit a627d5b into dotnet:master Sep 20, 2018
@yaeldekel yaeldekel deleted the shapesapi branch September 20, 2018 17:22
@ghost ghost locked as resolved and limited conversation to collaborators Mar 29, 2022
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants