Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Treat TensorFlow output as non-batched. #5634

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/Microsoft.ML.TensorFlow/TensorFlowModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ public sealed class TensorFlowModel : IDisposable
{
internal Session Session { get; }
internal string ModelPath { get; }
internal bool TreatOutputAsBatched { get; }

private readonly IHostEnvironment _env;

Expand All @@ -27,10 +28,12 @@ public sealed class TensorFlowModel : IDisposable
/// <param name="env">An <see cref="IHostEnvironment"/> object.</param>
/// <param name="session">TensorFlow session object.</param>
/// <param name="modelLocation">Location of the model from where <paramref name="session"/> was loaded.</param>
/// <param name="treatOutputAsBatched">If the first dimension of the output is unknown, should it be treated as batched or not.</param>
internal TensorFlowModel(IHostEnvironment env, Session session, string modelLocation, bool treatOutputAsBatched = true)
{
    Session = session;
    ModelPath = modelLocation;
    // Defaulting to true preserves the historical behavior of treating an
    // unknown leading output dimension as the batch dimension.
    TreatOutputAsBatched = treatOutputAsBatched;
    _env = env;
    _disposed = false;
}
Expand All @@ -40,7 +43,7 @@ internal TensorFlowModel(IHostEnvironment env, Session session, string modelLoca
/// </summary>
public DataViewSchema GetModelSchema()
{
    // Propagate the batching preference captured at load time so the returned
    // schema reflects how unknown leading output dimensions are interpreted.
    return TensorFlowUtils.GetModelSchema(_env, Session.graph, treatOutputAsBatched: TreatOutputAsBatched);
}

/// <summary>
Expand Down
26 changes: 26 additions & 0 deletions src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,31 @@ public static class TensorflowCatalog
/// </example>
public static TensorFlowModel LoadTensorFlowModel(this ModelOperationsCatalog catalog, string modelLocation)
=> TensorFlowUtils.LoadTensorFlowModel(CatalogUtils.GetEnvironment(catalog), modelLocation);

/// <summary>
/// Load TensorFlow model into memory. This is the convenience method that allows the model to be loaded once and subsequently use it for querying schema and creation of
/// <see cref="TensorFlowEstimator"/> using <see cref="TensorFlowModel.ScoreTensorFlowModel(string, string, bool)"/>.
/// Usage of this API requires additional NuGet dependencies on TensorFlow redist; see the linked document for more information.
/// <see cref="TensorFlowModel"/> also holds references to unmanaged resources that need to be freed either with an explicit
/// call to Dispose() or implicitly by declaring the variable with the "using" syntax.
///
/// <format type="text/markdown">
/// <![CDATA[
/// [!include[io](~/../docs/samples/docs/api-reference/tensorflow-usage.md)]
/// ]]>
/// </format>
/// </summary>
/// <param name="catalog">The transform's catalog.</param>
/// <param name="modelLocation">Location of the TensorFlow model.</param>
/// <param name="treatOutputAsBatched">If the first dimension of the output is unknown, should it be treated as batched or not.</param>
/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// [!code-csharp[LoadTensorFlowModel](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/TensorFlow/TextClassification.cs)]
/// ]]>
/// </format>
/// </example>
// Overload that lets callers control whether an unknown leading output
// dimension is treated as a batch dimension; delegates to TensorFlowUtils.
public static TensorFlowModel LoadTensorFlowModel(this ModelOperationsCatalog catalog, string modelLocation, bool treatOutputAsBatched)
    => TensorFlowUtils.LoadTensorFlowModel(CatalogUtils.GetEnvironment(catalog), modelLocation, treatOutputAsBatched);
}
}
48 changes: 33 additions & 15 deletions src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -82,27 +82,28 @@ private static VersionInfo GetVersionInfo()
/// Transform for scoring Tensorflow models. Input data column names/types must exactly match
/// all model input names. Only the output columns specified will be generated.
/// This convenience method avoids reloading of TensorFlow model.
/// It is useful in a situation where user has already loaded TensorFlow model using <see cref="TensorFlowUtils.LoadTensorFlowModel(IHostEnvironment, string)"/> for inspecting model schema.
/// It is useful in a situation where user has already loaded TensorFlow model using <see cref="TensorFlowUtils.LoadTensorFlowModel(IHostEnvironment, string, bool)"/> for inspecting model schema.
/// </summary>
/// <param name="env">The environment to use.</param>
/// <param name="tfModelInfo"> <see cref="TensorFlowModel"/> object created with <see cref="TensorFlowUtils.LoadTensorFlowModel(IHostEnvironment, string)"/>.</param>
/// <param name="tfModelInfo"> <see cref="TensorFlowModel"/> object created with <see cref="TensorFlowUtils.LoadTensorFlowModel(IHostEnvironment, string, bool)"/>.</param>
/// <param name="outputColumnName">The output columns to generate. Names must match model specifications. Data types are inferred from model.</param>
/// <param name="inputColumnName">The name of the input data columns. Must match model's input names. If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param>
/// <param name="addBatchDimensionInput">Add a batch dimension to the input e.g. input = [224, 224, 3] => [-1, 224, 224, 3].
/// This parameter is used to deal with models that have unknown shape but the internal operators in the model require data to have batch dimension as well.</param>
internal TensorFlowTransformer(IHostEnvironment env, TensorFlowModel tfModelInfo, string outputColumnName, string inputColumnName = null, bool addBatchDimensionInput = false)
: this(env, tfModelInfo.Session, new[] { outputColumnName }, new[] { inputColumnName ?? outputColumnName }, IsSavedModel(env, tfModelInfo.ModelPath) ? tfModelInfo.ModelPath : null, false, addBatchDimensionInput)
/// <param name="treatOutputAsBatched">If the first dimension of the output is unknown, should it be treated as batched or not.</param>
internal TensorFlowTransformer(IHostEnvironment env, TensorFlowModel tfModelInfo, string outputColumnName, string inputColumnName = null, bool addBatchDimensionInput = false, bool treatOutputAsBatched = true)
: this(env, tfModelInfo.Session, new[] { outputColumnName }, new[] { inputColumnName ?? outputColumnName }, IsSavedModel(env, tfModelInfo.ModelPath) ? tfModelInfo.ModelPath : null, false, addBatchDimensionInput, treatOutputAsBatched: treatOutputAsBatched)
{
}

/// <summary>
/// Transform for scoring Tensorflow models. Input data column names/types must exactly match
/// all model input names. Only the output columns specified will be generated.
/// This convenience method avoids reloading of TensorFlow model.
/// It is useful in a situation where user has already loaded TensorFlow model using <see cref="TensorFlowUtils.LoadTensorFlowModel(IHostEnvironment, string)"/> for inspecting model schema.
/// It is useful in a situation where user has already loaded TensorFlow model using <see cref="TensorFlowUtils.LoadTensorFlowModel(IHostEnvironment, string, bool)"/> for inspecting model schema.
/// </summary>
/// <param name="env">The environment to use.</param>
/// <param name="tfModelInfo"> <see cref="TensorFlowModel"/> object created with <see cref="TensorFlowUtils.LoadTensorFlowModel(IHostEnvironment, string)"/>.</param>
/// <param name="tfModelInfo"> <see cref="TensorFlowModel"/> object created with <see cref="TensorFlowUtils.LoadTensorFlowModel(IHostEnvironment, string, bool)"/>.</param>
/// <param name="inputColumnNames">The name of the input data columns. Must match model's input names.</param>
/// <param name="outputColumnNames">The output columns to generate. Names must match model specifications. Data types are inferred from model.</param>
/// <param name="addBatchDimensionInput">Add a batch dimension to the input e.g. input = [224, 224, 3] => [-1, 224, 224, 3].
Expand Down Expand Up @@ -267,7 +268,7 @@ private static void GetModelInfo(IHostEnvironment env, ModelLoadContext ctx, out

internal TensorFlowTransformer(IHostEnvironment env, Session session, string[] outputColumnNames,
string[] inputColumnNames, string savedModelPath, bool isTemporarySavedModel,
bool addBatchDimensionInput, int batchSize = 1, TensorFlowEstimator.Options options = null, IDataView input = null)
bool addBatchDimensionInput, int batchSize = 1, TensorFlowEstimator.Options options = null, IDataView input = null, bool treatOutputAsBatched = true)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TensorFlowTransformer)))

{
Expand All @@ -283,7 +284,7 @@ internal TensorFlowTransformer(IHostEnvironment env, Session session, string[] o
Outputs = outputColumnNames;
tf.compat.v1.disable_eager_execution();

(TFOutputTypes, OutputTypes, TFOutputOperations) = GetOutputInfo(Host, Session, Outputs);
(TFOutputTypes, OutputTypes, TFOutputOperations) = GetOutputInfo(Host, Session, Outputs, treatOutputAsBatched);
(TFInputTypes, TFInputShapes, TFInputOperations) = GetInputInfo(Host, Session, Inputs, batchSize);

TFInputNodes = new TF_Output[Inputs.Length];
Expand Down Expand Up @@ -359,7 +360,7 @@ internal static TensorShape GetTensorShape(TF_Output output, Graph graph, Status
return new TensorShape(dims.Select(x => (int)x).ToArray());
}

internal static (TF_DataType[] tfOutputTypes, DataViewType[] outputTypes, (Operation, int)[]) GetOutputInfo(IHost host, Session session, string[] outputs)
internal static (TF_DataType[] tfOutputTypes, DataViewType[] outputTypes, (Operation, int)[]) GetOutputInfo(IHost host, Session session, string[] outputs, bool treatOutputAsBatched)
{
var tfOutputTypes = new TF_DataType[outputs.Length];
var outputTypes = new DataViewType[outputs.Length];
Expand All @@ -384,7 +385,12 @@ internal static (TF_DataType[] tfOutputTypes, DataViewType[] outputTypes, (Opera
// If there are other dimension that are unknown the transformer will return a variable length vector.
// This is the work around in absence of reshape transformer.
var idims = shape.dims;
int[] dims = shape.ndim > 0 ? idims.Skip(idims[0] == -1 ? 1 : 0).ToArray() : new int[0];

int[] dims = idims;
if (treatOutputAsBatched)
{
dims = shape.ndim > 0 ? idims.Skip(idims[0] == -1 ? 1 : 0).ToArray() : new int[0];
}
for (int j = 0; j < dims.Length; j++)
dims[j] = dims[j] == -1 ? 0 : dims[j];
if (dims == null || dims.Length == 0)
Expand Down Expand Up @@ -876,6 +882,15 @@ internal sealed class Options : TransformInputBase
/// </remarks>
[Argument(ArgumentType.AtMostOnce, HelpText = "Add a batch dimension to the input e.g. input = [224, 224, 3] => [-1, 224, 224, 3].", SortOrder = 16)]
public bool AddBatchDimensionInputs = false;

/// <summary>
/// If the first dimension of the output is unknown, should it be treated as batched or not. e.g. output = [-1] will be read as a vector of unknown length when this is false.
/// </summary>
/// <remarks>
/// This parameter is used to deal with models that have unknown output shape and it needs to be interpreted in ML.NET as a vector of unknown length and not as a batch dimension.
/// </remarks>
[Argument(ArgumentType.AtMostOnce, HelpText = "If the first dimension of the output is unknown, should it be treated as batched or not. e.g. output = [-1] will be read as a vector of unkown length when this is false.", SortOrder = 17)]
michaelgsharp marked this conversation as resolved.
Show resolved Hide resolved
public bool TreatOutputAsBatched = true;
}

private readonly IHost _host;
Expand All @@ -897,7 +912,7 @@ internal TensorFlowEstimator(IHostEnvironment env, string[] outputColumnNames, s
}

// Loads the model from options.ModelLocation, honoring options.TreatOutputAsBatched,
// then delegates to the (env, options, model) constructor.
internal TensorFlowEstimator(IHostEnvironment env, Options options)
    : this(env, options, TensorFlowUtils.LoadTensorFlowModel(env, options.ModelLocation, options.TreatOutputAsBatched))
{
}

Expand All @@ -906,20 +921,23 @@ internal TensorFlowEstimator(IHostEnvironment env, Options options, TensorFlowMo
_host = Contracts.CheckRef(env, nameof(env)).Register(nameof(TensorFlowEstimator));
_options = options;
_tensorFlowModel = tensorFlowModel;
if (!tensorFlowModel.TreatOutputAsBatched)
_options.TreatOutputAsBatched = tensorFlowModel.TreatOutputAsBatched;
tensorFlowModel.Session.graph.as_default();
var inputTuple = TensorFlowTransformer.GetInputInfo(_host, tensorFlowModel.Session, options.InputColumns);
var inputTuple = TensorFlowTransformer.GetInputInfo(_host, tensorFlowModel.Session, _options.InputColumns);
_tfInputTypes = inputTuple.tfInputTypes;
var outputTuple = TensorFlowTransformer.GetOutputInfo(_host, tensorFlowModel.Session, options.OutputColumns);
var outputTuple = TensorFlowTransformer.GetOutputInfo(_host, tensorFlowModel.Session, _options.OutputColumns, _options.TreatOutputAsBatched);
_outputTypes = outputTuple.outputTypes;
}

/// <summary>
/// Bundles the individual estimator settings into an <see cref="Options"/> bag,
/// taking the model path from the already-loaded <paramref name="tensorFlowModel"/>.
/// </summary>
private static Options CreateArguments(TensorFlowModel tensorFlowModel, string[] outputColumnNames, string[] inputColumnName, bool addBatchDimensionInput, bool treatOutputAsBatched = true)
{
    return new Options
    {
        ModelLocation = tensorFlowModel.ModelPath,
        InputColumns = inputColumnName,
        OutputColumns = outputColumnNames,
        AddBatchDimensionInputs = addBatchDimensionInput,
        TreatOutputAsBatched = treatOutputAsBatched
    };
}

Expand Down Expand Up @@ -959,7 +977,7 @@ public TensorFlowTransformer Fit(IDataView input)
if (_transformer == null)
{
_transformer = new TensorFlowTransformer(_host, _tensorFlowModel.Session, _options.OutputColumns, _options.InputColumns,
IsSavedModel(_host, _options.ModelLocation) ? _options.ModelLocation : null, false, _options.AddBatchDimensionInputs);
IsSavedModel(_host, _options.ModelLocation) ? _options.ModelLocation : null, false, _options.AddBatchDimensionInputs, treatOutputAsBatched: _options.TreatOutputAsBatched);
}
// Validate input schema.
_transformer.GetOutputSchema(input.Schema);
Expand Down
27 changes: 19 additions & 8 deletions src/Microsoft.ML.TensorFlow/TensorflowUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ internal static class TensorFlowUtils
/// </summary>
internal const string TensorflowUpstreamOperatorsKind = "TensorflowUpstreamOperators";

internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph graph, string opType = null)
internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph graph, string opType = null, bool treatOutputAsBatched = true)
michaelgsharp marked this conversation as resolved.
Show resolved Hide resolved
{
var schemaBuilder = new DataViewSchema.Builder();
foreach (Operation op in graph)
Expand Down Expand Up @@ -79,7 +79,7 @@ internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph grap
// Construct the final ML.NET type of a Tensorflow variable.
var tensorShape = op.output.TensorShape.dims;

if(tensorShape == null)
if (tensorShape == null)
{
// primitive column type
schemaBuilder.AddColumn(op.name, mlType, metadataBuilder.ToAnnotations());
Expand All @@ -90,7 +90,16 @@ internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph grap
DataViewType columnType = new VectorDataViewType(mlType);
if (!(Utils.Size(tensorShape) == 1 && tensorShape[0] <= 0) &&
(Utils.Size(tensorShape) > 0 && tensorShape.Skip(1).All(x => x > 0)))
columnType = new VectorDataViewType(mlType, tensorShape[0] > 0 ? tensorShape : tensorShape.Skip(1).ToArray());
if (treatOutputAsBatched)
michaelgsharp marked this conversation as resolved.
Show resolved Hide resolved
{
columnType = new VectorDataViewType(mlType, tensorShape[0] > 0 ? tensorShape : tensorShape.Skip(1).ToArray());
}
else
{
if (tensorShape[0] < 0)
tensorShape[0] = 0;
columnType = new VectorDataViewType(mlType, tensorShape);
}

schemaBuilder.AddColumn(op.name, columnType, metadataBuilder.ToAnnotations());
}
Expand All @@ -108,22 +117,24 @@ internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph grap
/// </summary>
/// <param name="env">The environment to use.</param>
/// <param name="modelPath">Model to load.</param>
/// <param name="treatOutputAsBatched">If the first dimension of the output is unknown, should it be treated as batched or not.</param>
internal static DataViewSchema GetModelSchema(IHostEnvironment env, string modelPath, bool treatOutputAsBatched = true)
{
    // Load the model only long enough to inspect its graph; the using
    // declaration disposes the session before this method returns.
    using var model = LoadTensorFlowModel(env, modelPath, treatOutputAsBatched);
    return GetModelSchema(env, model.Session.graph, treatOutputAsBatched: treatOutputAsBatched);
}

/// <summary>
/// Load TensorFlow model into memory.
/// </summary>
/// <param name="env">The environment to use.</param>
/// <param name="modelPath">The model to load.</param>
/// <param name="treatOutputAsBatched">If the first dimension of the output is unknown, should it be treated as batched or not.</param>
/// <returns></returns>
internal static TensorFlowModel LoadTensorFlowModel(IHostEnvironment env, string modelPath, bool treatOutputAsBatched = true)
{
    var session = GetSession(env, modelPath);
    // The wrapper records the batching preference so later schema queries
    // interpret unknown leading output dimensions consistently.
    return new TensorFlowModel(env, session, modelPath, treatOutputAsBatched: treatOutputAsBatched);
}

internal static PrimitiveDataViewType Tf2MlNetType(TF_DataType type)
Expand Down
9 changes: 9 additions & 0 deletions test/BaselineOutput/Common/EntryPoints/core_manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -23613,6 +23613,15 @@
"SortOrder": 16.0,
"IsNullable": false,
"Default": false
},
{
"Name": "TreatOutputAsBatched",
"Type": "Bool",
"Desc": "If the first dimension of the output is unknown, should it be treated as batched or not. e.g. output = [-1] will be read as a vector of unkown length when this is false.",
"Required": false,
"SortOrder": 17.0,
"IsNullable": false,
"Default": true
}
],
"Outputs": [
Expand Down
Loading