diff --git a/src/Microsoft.ML.TensorFlow/TensorFlowModel.cs b/src/Microsoft.ML.TensorFlow/TensorFlowModel.cs index f8fd1b476e..7bff665ba6 100644 --- a/src/Microsoft.ML.TensorFlow/TensorFlowModel.cs +++ b/src/Microsoft.ML.TensorFlow/TensorFlowModel.cs @@ -18,6 +18,7 @@ public sealed class TensorFlowModel : IDisposable { internal Session Session { get; } internal string ModelPath { get; } + internal bool TreatOutputAsBatched { get; } private readonly IHostEnvironment _env; @@ -27,10 +28,12 @@ public sealed class TensorFlowModel : IDisposable /// An object. /// TensorFlow session object. /// Location of the model from where was loaded. - internal TensorFlowModel(IHostEnvironment env, Session session, string modelLocation) + /// If the first dimension of the output is unknown, should it be treated as batched or not. + internal TensorFlowModel(IHostEnvironment env, Session session, string modelLocation, bool treatOutputAsBatched = true) { Session = session; ModelPath = modelLocation; + TreatOutputAsBatched = treatOutputAsBatched; _env = env; _disposed = false; } @@ -40,7 +43,7 @@ internal TensorFlowModel(IHostEnvironment env, Session session, string modelLoca /// public DataViewSchema GetModelSchema() { - return TensorFlowUtils.GetModelSchema(_env, Session.graph); + return TensorFlowUtils.GetModelSchema(_env, Session.graph, TreatOutputAsBatched); } /// @@ -49,7 +52,7 @@ public DataViewSchema GetModelSchema() /// public DataViewSchema GetInputSchema() { - return TensorFlowUtils.GetModelSchema(_env, Session.graph, "Placeholder"); + return TensorFlowUtils.GetModelSchema(_env, Session.graph, TreatOutputAsBatched, "Placeholder"); } /// diff --git a/src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs b/src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs index 2ad63321d4..372d4b1029 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs @@ -35,5 +35,31 @@ public static class TensorflowCatalog /// public static TensorFlowModel LoadTensorFlowModel(this ModelOperationsCatalog catalog, string modelLocation) => TensorFlowUtils.LoadTensorFlowModel(CatalogUtils.GetEnvironment(catalog), modelLocation); + + /// + /// Load TensorFlow model into memory. This is the convenience method that allows the model to be loaded once and subsequently use it for querying schema and creation of + /// using . + /// usage of this API requires additional NuGet dependencies on TensorFlow redist, see linked document for more information. + /// also holds references to unmanaged resources that need to be freed either with an explicit + /// call to Dispose() or implicitly by declaring the variable with the "using" syntax/> + /// + /// + /// + /// + /// + /// The transform's catalog. + /// Location of the TensorFlow model. + /// If the first dimension of the output is unknown, should it be treated as batched or not. + /// + /// + /// + /// + /// + public static TensorFlowModel LoadTensorFlowModel(this ModelOperationsCatalog catalog, string modelLocation, bool treatOutputAsBatched) + => TensorFlowUtils.LoadTensorFlowModel(CatalogUtils.GetEnvironment(catalog), modelLocation, treatOutputAsBatched); } } diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs index f5e85f158f..1537ab10cf 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs @@ -45,6 +45,7 @@ public sealed class TensorFlowTransformer : RowToRowTransformerBase, IDisposable private readonly string _savedModelPath; private readonly bool _isTemporarySavedModel; private readonly bool _addBatchDimensionInput; + private readonly bool _treatOutputAsBatched; internal readonly Session Session; internal readonly Runner Runner; internal readonly DataViewType[] OutputTypes; @@ -71,8 +72,9 @@ private static VersionInfo GetVersionInfo() modelSignature: "TENSFLOW", //verWrittenCur: 0x00010001, // Initial //verWrittenCur: 0x00010002, // Added Support for Multiple Outputs and SavedModel. - verWrittenCur: 0x00010003, // Added Support for adding batch dimension in inputs. - verReadableCur: 0x00010003, + //verWrittenCur: 0x00010003, // Added Support for adding batch dimension in inputs. + verWrittenCur: 0x00010004, // Added Support for treating batch as output or not. + verReadableCur: 0x00010004, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, loaderAssemblyName: typeof(TensorFlowTransformer).Assembly.FullName); @@ -82,16 +84,17 @@ private static VersionInfo GetVersionInfo() /// Transform for scoring Tensorflow models. Input data column names/types must exactly match /// all model input names. Only the output columns specified will be generated. /// This convenience method avoids reloading of TensorFlow model. - /// It is useful in a situation where user has already loaded TensorFlow model using for inspecting model schema. + /// It is useful in a situation where user has already loaded TensorFlow model using for inspecting model schema. /// /// The environment to use. - /// object created with . + /// object created with . /// The output columns to generate. Names must match model specifications. Data types are inferred from model. /// The name of the input data columns. Must match model's input names. If set to , the value of the will be used as source. /// Add a batch dimension to the input e.g. input = [224, 224, 3] => [-1, 224, 224, 3]. /// This parameter is used to deal with models that have unknown shape but the internal operators in the model require data to have batch dimension as well. - internal TensorFlowTransformer(IHostEnvironment env, TensorFlowModel tfModelInfo, string outputColumnName, string inputColumnName = null, bool addBatchDimensionInput = false) - : this(env, tfModelInfo.Session, new[] { outputColumnName }, new[] { inputColumnName ?? outputColumnName }, IsSavedModel(env, tfModelInfo.ModelPath) ? tfModelInfo.ModelPath : null, false, addBatchDimensionInput) + /// If the first dimension of the output is unknown, should it be treated as batched or not. + internal TensorFlowTransformer(IHostEnvironment env, TensorFlowModel tfModelInfo, string outputColumnName, string inputColumnName = null, bool addBatchDimensionInput = false, bool treatOutputAsBatched = true) + : this(env, tfModelInfo.Session, new[] { outputColumnName }, new[] { inputColumnName ?? outputColumnName }, IsSavedModel(env, tfModelInfo.ModelPath) ? tfModelInfo.ModelPath : null, false, addBatchDimensionInput, treatOutputAsBatched: treatOutputAsBatched) { } @@ -99,16 +102,17 @@ internal TensorFlowTransformer(IHostEnvironment env, TensorFlowModel tfModelInfo /// Transform for scoring Tensorflow models. Input data column names/types must exactly match /// all model input names. Only the output columns specified will be generated. /// This convenience method avoids reloading of TensorFlow model. - /// It is useful in a situation where user has already loaded TensorFlow model using for inspecting model schema. + /// It is useful in a situation where user has already loaded TensorFlow model using for inspecting model schema. /// /// The environment to use. - /// object created with . + /// object created with . /// The name of the input data columns. Must match model's input names. /// The output columns to generate. Names must match model specifications. Data types are inferred from model. /// Add a batch dimension to the input e.g. input = [224, 224, 3] => [-1, 224, 224, 3]. /// This parameter is used to deal with models that have unknown shape but the internal operators in the model require data to have batch dimension as well. - internal TensorFlowTransformer(IHostEnvironment env, TensorFlowModel tfModelInfo, string[] outputColumnNames, string[] inputColumnNames, bool addBatchDimensionInput = false) - : this(env, tfModelInfo.Session, outputColumnNames, inputColumnNames, IsSavedModel(env, tfModelInfo.ModelPath) ? tfModelInfo.ModelPath : null, false, addBatchDimensionInput) + /// If the first dimension of the output is unknown, should it be treated as batched or not. + internal TensorFlowTransformer(IHostEnvironment env, TensorFlowModel tfModelInfo, string[] outputColumnNames, string[] inputColumnNames, bool addBatchDimensionInput = false, bool treatOutputAsBatched = true) + : this(env, tfModelInfo.Session, outputColumnNames, inputColumnNames, IsSavedModel(env, tfModelInfo.ModelPath) ? tfModelInfo.ModelPath : null, false, addBatchDimensionInput, treatOutputAsBatched: treatOutputAsBatched) { } @@ -122,6 +126,7 @@ private static TensorFlowTransformer Create(IHostEnvironment env, ModelLoadConte // *** Binary format *** // byte: indicator for frozen models // byte: indicator for adding batch dimension in input + // byte: indicator for treating output as batched // stream: tensorFlow model. // int: number of input columns // for each input column @@ -129,13 +134,13 @@ private static TensorFlowTransformer Create(IHostEnvironment env, ModelLoadConte // int: number of output columns // for each output column // int: id of output column name - GetModelInfo(env, ctx, out string[] inputs, out string[] outputs, out bool isFrozen, out bool addBatchDimensionInput); + GetModelInfo(env, ctx, out string[] inputs, out string[] outputs, out bool isFrozen, out bool addBatchDimensionInput, out bool treatOutputAsBatched); if (isFrozen) { byte[] modelBytes = null; if (!ctx.TryLoadBinaryStream("TFModel", r => modelBytes = r.ReadByteArray())) throw env.ExceptDecode(); - return new TensorFlowTransformer(env, LoadTFSession(env, modelBytes), outputs, inputs, null, false, addBatchDimensionInput); + return new TensorFlowTransformer(env, LoadTFSession(env, modelBytes), outputs, inputs, null, false, addBatchDimensionInput, treatOutputAsBatched: treatOutputAsBatched); } var tempDirPath = Path.GetFullPath(Path.Combine(Path.GetTempPath(), nameof(TensorFlowTransformer) + "_" + Guid.NewGuid())); @@ -164,7 +169,7 @@ private static TensorFlowTransformer Create(IHostEnvironment env, ModelLoadConte } }); - return new TensorFlowTransformer(env, GetSession(env, tempDirPath), outputs, inputs, tempDirPath, true, addBatchDimensionInput); + return new TensorFlowTransformer(env, GetSession(env, tempDirPath), outputs, inputs, tempDirPath, true, addBatchDimensionInput, treatOutputAsBatched: treatOutputAsBatched); } catch (Exception) { @@ -236,7 +241,7 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema) => Create(env, ctx).MakeRowMapper(inputSchema); - private static void GetModelInfo(IHostEnvironment env, ModelLoadContext ctx, out string[] inputs, out string[] outputs, out bool isFrozen, out bool addBatchDimensionInput) + private static void GetModelInfo(IHostEnvironment env, ModelLoadContext ctx, out string[] inputs, out string[] outputs, out bool isFrozen, out bool addBatchDimensionInput, out bool treatOutputAsBatched) { isFrozen = true; bool isNonFrozenModelSupported = ctx.Header.ModelVerReadable >= 0x00010002; @@ -248,6 +253,11 @@ private static void GetModelInfo(IHostEnvironment env, ModelLoadContext ctx, out if (isAddingBatchDimensionSupported) addBatchDimensionInput = ctx.Reader.ReadBoolByte(); + treatOutputAsBatched = true; + bool isTreatingOutputAsBatchedSupported = ctx.Header.ModelVerReadable >= 0x00010004; + if (isTreatingOutputAsBatchedSupported) + treatOutputAsBatched = ctx.Reader.ReadBoolByte(); + var numInputs = ctx.Reader.ReadInt32(); env.CheckDecode(numInputs > 0); inputs = new string[numInputs]; @@ -267,7 +277,7 @@ private static void GetModelInfo(IHostEnvironment env, ModelLoadContext ctx, out internal TensorFlowTransformer(IHostEnvironment env, Session session, string[] outputColumnNames, string[] inputColumnNames, string savedModelPath, bool isTemporarySavedModel, - bool addBatchDimensionInput, int batchSize = 1, TensorFlowEstimator.Options options = null, IDataView input = null) + bool addBatchDimensionInput, int batchSize = 1, TensorFlowEstimator.Options options = null, IDataView input = null, bool treatOutputAsBatched = true) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TensorFlowTransformer))) { @@ -279,11 +289,12 @@ internal TensorFlowTransformer(IHostEnvironment env, Session session, string[] o _isTemporarySavedModel = isTemporarySavedModel; Session = session; _addBatchDimensionInput = addBatchDimensionInput; + _treatOutputAsBatched = treatOutputAsBatched; Inputs = inputColumnNames; Outputs = outputColumnNames; tf.compat.v1.disable_eager_execution(); - (TFOutputTypes, OutputTypes, TFOutputOperations) = GetOutputInfo(Host, Session, Outputs); + (TFOutputTypes, OutputTypes, TFOutputOperations) = GetOutputInfo(Host, Session, Outputs, treatOutputAsBatched); (TFInputTypes, TFInputShapes, TFInputOperations) = GetInputInfo(Host, Session, Inputs, batchSize); TFInputNodes = new TF_Output[Inputs.Length]; @@ -359,7 +370,7 @@ internal static TensorShape GetTensorShape(TF_Output output, Graph graph, Status return new TensorShape(dims.Select(x => (int)x).ToArray()); } - internal static (TF_DataType[] tfOutputTypes, DataViewType[] outputTypes, (Operation, int)[]) GetOutputInfo(IHost host, Session session, string[] outputs) + internal static (TF_DataType[] tfOutputTypes, DataViewType[] outputTypes, (Operation, int)[]) GetOutputInfo(IHost host, Session session, string[] outputs, bool treatOutputAsBatched) { var tfOutputTypes = new TF_DataType[outputs.Length]; var outputTypes = new DataViewType[outputs.Length]; @@ -384,7 +395,12 @@ internal static (TF_DataType[] tfOutputTypes, DataViewType[] outputTypes, (Opera // If there are other dimension that are unknown the transformer will return a variable length vector. // This is the work around in absence of reshape transformer. var idims = shape.dims; - int[] dims = shape.ndim > 0 ? idims.Skip(idims[0] == -1 ? 1 : 0).ToArray() : new int[0]; + + int[] dims = idims; + if (treatOutputAsBatched) + { + dims = shape.ndim > 0 ? idims.Skip(idims[0] == -1 ? 1 : 0).ToArray() : new int[0]; + } for (int j = 0; j < dims.Length; j++) dims[j] = dims[j] == -1 ? 0 : dims[j]; if (dims == null || dims.Length == 0) @@ -415,6 +431,7 @@ private protected override void SaveModel(ModelSaveContext ctx) // *** Binary format *** // byte: indicator for frozen models // byte: indicator for adding batch dimension in input + // byte: indicator for treating output as batched // stream: tensorFlow model. // int: number of input columns // for each input column @@ -425,6 +442,7 @@ private protected override void SaveModel(ModelSaveContext ctx) var isFrozen = string.IsNullOrEmpty(_savedModelPath); ctx.Writer.WriteBoolByte(isFrozen); ctx.Writer.WriteBoolByte(_addBatchDimensionInput); + ctx.Writer.WriteBoolByte(_treatOutputAsBatched); if (isFrozen) { using (var status = new Status()) @@ -876,6 +894,15 @@ internal sealed class Options : TransformInputBase /// [Argument(ArgumentType.AtMostOnce, HelpText = "Add a batch dimension to the input e.g. input = [224, 224, 3] => [-1, 224, 224, 3].", SortOrder = 16)] public bool AddBatchDimensionInputs = false; + + /// + /// If the first dimension of the output is unknown, should it be treated as batched or not. e.g. output = [-1] will be read as a vector of unknown length when this is false. + /// + /// + /// This parameter is used to deal with models that have unknown output shape and it needs to be interpreted in ML.NET as a vector of unknown length and not as a batch dimension. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "If the first dimension of the output is unknown, should it be treated as batched or not. e.g. output = [-1] will be read as a vector of unknown length when this is false.", SortOrder = 17)] + public bool TreatOutputAsBatched = true; } private readonly IHost _host; @@ -897,7 +924,7 @@ internal TensorFlowEstimator(IHostEnvironment env, string[] outputColumnNames, s } internal TensorFlowEstimator(IHostEnvironment env, Options options) - : this(env, options, TensorFlowUtils.LoadTensorFlowModel(env, options.ModelLocation)) + : this(env, options, TensorFlowUtils.LoadTensorFlowModel(env, options.ModelLocation, options.TreatOutputAsBatched)) { } @@ -906,20 +933,23 @@ internal TensorFlowEstimator(IHostEnvironment env, Options options, TensorFlowMo _host = Contracts.CheckRef(env, nameof(env)).Register(nameof(TensorFlowEstimator)); _options = options; _tensorFlowModel = tensorFlowModel; + if (!tensorFlowModel.TreatOutputAsBatched) + _options.TreatOutputAsBatched = tensorFlowModel.TreatOutputAsBatched; tensorFlowModel.Session.graph.as_default(); - var inputTuple = TensorFlowTransformer.GetInputInfo(_host, tensorFlowModel.Session, options.InputColumns); + var inputTuple = TensorFlowTransformer.GetInputInfo(_host, tensorFlowModel.Session, _options.InputColumns); _tfInputTypes = inputTuple.tfInputTypes; - var outputTuple = TensorFlowTransformer.GetOutputInfo(_host, tensorFlowModel.Session, options.OutputColumns); + var outputTuple = TensorFlowTransformer.GetOutputInfo(_host, tensorFlowModel.Session, _options.OutputColumns, _options.TreatOutputAsBatched); _outputTypes = outputTuple.outputTypes; } - private static Options CreateArguments(TensorFlowModel tensorFlowModel, string[] outputColumnNames, string[] inputColumnName, bool addBatchDimensionInput) + private static Options CreateArguments(TensorFlowModel tensorFlowModel, string[] outputColumnNames, string[] inputColumnName, bool addBatchDimensionInput, bool treatOutputAsBatched = true) { var options = new Options(); options.ModelLocation = tensorFlowModel.ModelPath; options.InputColumns = inputColumnName; options.OutputColumns = outputColumnNames; options.AddBatchDimensionInputs = addBatchDimensionInput; + options.TreatOutputAsBatched = treatOutputAsBatched; return options; } @@ -959,7 +989,7 @@ public TensorFlowTransformer Fit(IDataView input) if (_transformer == null) { _transformer = new TensorFlowTransformer(_host, _tensorFlowModel.Session, _options.OutputColumns, _options.InputColumns, - IsSavedModel(_host, _options.ModelLocation) ? _options.ModelLocation : null, false, _options.AddBatchDimensionInputs); + IsSavedModel(_host, _options.ModelLocation) ? _options.ModelLocation : null, false, _options.AddBatchDimensionInputs, treatOutputAsBatched: _options.TreatOutputAsBatched); } // Validate input schema. _transformer.GetOutputSchema(input.Schema); diff --git a/src/Microsoft.ML.TensorFlow/TensorflowUtils.cs b/src/Microsoft.ML.TensorFlow/TensorflowUtils.cs index 805aedcb3d..8fbbd772a0 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowUtils.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowUtils.cs @@ -32,7 +32,7 @@ internal static class TensorFlowUtils /// internal const string TensorflowUpstreamOperatorsKind = "TensorflowUpstreamOperators"; - internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph graph, string opType = null) + internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph graph, bool treatOutputAsBatched, string opType = null) { var schemaBuilder = new DataViewSchema.Builder(); foreach (Operation op in graph) @@ -79,7 +79,7 @@ internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph grap // Construct the final ML.NET type of a Tensorflow variable. var tensorShape = op.output.TensorShape.dims; - if(tensorShape == null) + if (tensorShape == null) { // primitive column type schemaBuilder.AddColumn(op.name, mlType, metadataBuilder.ToAnnotations()); @@ -90,7 +90,24 @@ internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph grap DataViewType columnType = new VectorDataViewType(mlType); if (!(Utils.Size(tensorShape) == 1 && tensorShape[0] <= 0) && (Utils.Size(tensorShape) > 0 && tensorShape.Skip(1).All(x => x > 0))) - columnType = new VectorDataViewType(mlType, tensorShape[0] > 0 ? tensorShape : tensorShape.Skip(1).ToArray()); + // treatOutputAsBatched == true means that if the first dimension is greater + // than 0 we take the tensor shape as is. If the first value is less then 0, we treat it as the batch input so we can + // ignore it for the shape of the ML.NET vector. I.E. if the input dimensions are [-1, 5], ML.NET will read the -1 as + // batch input, and so the ML.NET data type will be a vector of length 5. + if (treatOutputAsBatched) + { + columnType = new VectorDataViewType(mlType, tensorShape[0] > 0 ? tensorShape : tensorShape.Skip(1).ToArray()); + } + // When treatOutputAsBatched is false, if the first value is less than 0 we want to set it to 0. TensorFlow + // represents an unknown size as -1, but ML.NET represents it as 0 so we need to convert it. + // I.E. if the input dimensions are [-1, 5], ML.NET will read the -1 as a dimension of unknown length, and so the ML.NET + // data type will be a vector of 2 dimensions, where the first dimension is unknown and the second has a length of 5. + else + { + if (tensorShape[0] < 0) + tensorShape[0] = 0; + columnType = new VectorDataViewType(mlType, tensorShape); + } schemaBuilder.AddColumn(op.name, columnType, metadataBuilder.ToAnnotations()); } @@ -108,10 +125,11 @@ internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph grap /// /// The environment to use. /// Model to load. - internal static DataViewSchema GetModelSchema(IHostEnvironment env, string modelPath) + /// If the first dimension of the output is unknown, should it be treated as batched or not. + internal static DataViewSchema GetModelSchema(IHostEnvironment env, string modelPath, bool treatOutputAsBatched = true) { - using var model = LoadTensorFlowModel(env, modelPath); - return GetModelSchema(env, model.Session.graph); + using var model = LoadTensorFlowModel(env, modelPath, treatOutputAsBatched); + return GetModelSchema(env, model.Session.graph, treatOutputAsBatched); } /// @@ -119,11 +137,12 @@ internal static DataViewSchema GetModelSchema(IHostEnvironment env, string model /// /// The environment to use. /// The model to load. + /// If the first dimension of the output is unknown, should it be treated as batched or not. /// - internal static TensorFlowModel LoadTensorFlowModel(IHostEnvironment env, string modelPath) + internal static TensorFlowModel LoadTensorFlowModel(IHostEnvironment env, string modelPath, bool treatOutputAsBatched = true) { var session = GetSession(env, modelPath); - return new TensorFlowModel(env, session, modelPath); + return new TensorFlowModel(env, session, modelPath, treatOutputAsBatched: treatOutputAsBatched); } internal static PrimitiveDataViewType Tf2MlNetType(TF_DataType type) diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index 429ff6bb78..7253e4533f 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -23613,6 +23613,15 @@ "SortOrder": 16.0, "IsNullable": false, "Default": false + }, + { + "Name": "TreatOutputAsBatched", + "Type": "Bool", + "Desc": "If the first dimension of the output is unknown, should it be treated as batched or not. e.g. output = [-1] will be read as a vector of unknown length when this is false.", + "Required": false, + "SortOrder": 17.0, + "IsNullable": false, + "Default": true } ], "Outputs": [ diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index 175e958f1b..3508cfe481 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -1152,7 +1152,6 @@ public void TensorFlowGettingSchemaMultipleTimes() } } - [TensorFlowFact] public void TensorFlowTransformCifarInvalidShape() { diff --git a/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs b/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs index 8aba9b08a5..ff6dbd456f 100644 --- a/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs +++ b/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs @@ -184,6 +184,52 @@ public void TestTensorFlow() } } + [TensorFlowFact] + public void TreatOutputAsBatched() + { + var modelLocation = "cifar_model/frozen_model.pb"; + + var mlContext = new MLContext(seed: 1); + var imageHeight = 32; + var imageWidth = 32; + var dataFile = GetDataPath("images/images.tsv"); + var imageFolder = Path.GetDirectoryName(dataFile); + + var data = ML.Data.LoadFromTextFile(dataFile, new[] { + new TextLoader.Column("imagePath", DataKind.String, 0), + new TextLoader.Column("name", DataKind.String, 1) + }); + + // Note that CamelCase column names are there to match the TF graph node names. + // Check and make sure save/load work correctly for the new TreatOutputAsBatched value. + var pipe = ML.Transforms.LoadImages("Input", imageFolder, "imagePath") + .Append(ML.Transforms.ResizeImages("Input", imageHeight, imageWidth)) + .Append(ML.Transforms.ExtractPixels("Input", interleavePixelColors: true)) + .Append(ML.Model.LoadTensorFlowModel(modelLocation, false).ScoreTensorFlowModel("Output", "Input")); + + TestEstimatorCore(pipe, data); + var schema = pipe.Fit(data).Transform(data).Schema; + + // The dimensions of the output with treatOutputAsBatched set to false should be * 10 + // as the first dimension of -1 is treated as an unknown dimension. + Assert.Equal(new VectorDataViewType(NumberDataViewType.Single, 0, 10), schema["Output"].Type); + + // Note that CamelCase column names are there to match the TF graph node names. + // Test with TreatOutputAsBatched set to default value of true. + pipe = ML.Transforms.LoadImages("Input", imageFolder, "imagePath") + .Append(ML.Transforms.ResizeImages("Input", imageHeight, imageWidth)) + .Append(ML.Transforms.ExtractPixels("Input", interleavePixelColors: true)) + .Append(ML.Model.LoadTensorFlowModel(modelLocation).ScoreTensorFlowModel("Output", "Input")); + + TestEstimatorCore(pipe, data); + schema = pipe.Fit(data).Transform(data).Schema; + + // The dimensions of the output with treatOutputAsBatched set to true should be 10 + // as the first dimension of -1 is treated as the batch dimension. + Assert.Equal(new VectorDataViewType(NumberDataViewType.Single, 10), schema["Output"].Type); + + } + [TensorFlowFact] public void TestTensorFlowWithSchema() {