Skip to content

Commit

Permalink
Updates based on PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelgsharp committed Mar 2, 2021
1 parent d045354 commit b29b60d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
4 changes: 2 additions & 2 deletions src/Microsoft.ML.TensorFlow/TensorFlowModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ internal TensorFlowModel(IHostEnvironment env, Session session, string modelLoca
/// </summary>
public DataViewSchema GetModelSchema()
{
return TensorFlowUtils.GetModelSchema(_env, Session.graph, treatOutputAsBatched: TreatOutputAsBatched);
return TensorFlowUtils.GetModelSchema(_env, Session.graph, TreatOutputAsBatched);
}

/// <summary>
Expand All @@ -52,7 +52,7 @@ public DataViewSchema GetModelSchema()
/// </summary>
public DataViewSchema GetInputSchema()
{
return TensorFlowUtils.GetModelSchema(_env, Session.graph, "Placeholder");
return TensorFlowUtils.GetModelSchema(_env, Session.graph, TreatOutputAsBatched, "Placeholder");
}

/// <summary>
Expand Down
10 changes: 5 additions & 5 deletions src/Microsoft.ML.TensorFlow/TensorflowUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ internal static class TensorFlowUtils
/// </summary>
internal const string TensorflowUpstreamOperatorsKind = "TensorflowUpstreamOperators";

internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph graph, string opType = null, bool treatOutputAsBatched = true)
internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph graph, bool treatOutputAsBatched, string opType = null)
{
var schemaBuilder = new DataViewSchema.Builder();
foreach (Operation op in graph)
Expand Down Expand Up @@ -99,9 +99,9 @@ internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph grap
columnType = new VectorDataViewType(mlType, tensorShape[0] > 0 ? tensorShape : tensorShape.Skip(1).ToArray());
}
// When treatOutputAsBatched is false, if the first value is less than 0 we want to set it to 0. TensorFlow
// represents and unkown size as -1, but ML.NET represents it as 0 so we need to convert it.
//I.E. if the input dimensions are [-1, 5], ML.NET will read the -1 as a dimension of unkown length, and so the ML.NET
//data type will be a vector of 2 dimensions, where the first dimension is unkown and the second has a length of 5.
// represents an unkown size as -1, but ML.NET represents it as 0 so we need to convert it.
// I.E. if the input dimensions are [-1, 5], ML.NET will read the -1 as a dimension of unkown length, and so the ML.NET
// data type will be a vector of 2 dimensions, where the first dimension is unkown and the second has a length of 5.
else
{
if (tensorShape[0] < 0)
Expand Down Expand Up @@ -129,7 +129,7 @@ internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph grap
internal static DataViewSchema GetModelSchema(IHostEnvironment env, string modelPath, bool treatOutputAsBatched = true)
{
using var model = LoadTensorFlowModel(env, modelPath, treatOutputAsBatched);
return GetModelSchema(env, model.Session.graph, treatOutputAsBatched: treatOutputAsBatched);
return GetModelSchema(env, model.Session.graph, treatOutputAsBatched);
}

/// <summary>
Expand Down

0 comments on commit b29b60d

Please sign in to comment.