Skip to content

Commit

Permalink
Tensorflow fix (#5547)
Browse files Browse the repository at this point in the history
* fix tensorflow issue on sample repo

* add comments
  • Loading branch information
frank-dong-ms-zz authored Dec 11, 2020
1 parent 5318cc2 commit 3e72d19
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 13 deletions.
13 changes: 12 additions & 1 deletion src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,18 @@ public Mapper(TensorFlowTransformer parent, DataViewSchema inputSchema) :
var shape = originalShape.dims;

if (shape == null || (shape.Length == 0))
_fullySpecifiedShapes[i] = new TensorShape();
{
// for vector type input TensorShape should same to dim
if (_isInputVector[i])
{
vecType = (VectorDataViewType)type;
var colTypeDims = vecType.Dimensions.Select(dim => (int)dim).ToArray();
_fullySpecifiedShapes[i] = new TensorShape(colTypeDims);
}
else
// for primitive type use default TensorShape
_fullySpecifiedShapes[i] = new TensorShape();
}
else
{
vecType = (VectorDataViewType)type;
Expand Down
26 changes: 18 additions & 8 deletions src/Microsoft.ML.TensorFlow/TensorflowUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,6 @@ internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph grap
if (mlType == null || op.NumOutputs <= 0)
continue;

// Construct the final ML.NET type of a Tensorflow variable.
var tensorShape = op.output.TensorShape.dims;
var columnType = new VectorDataViewType(mlType);
if (!(Utils.Size(tensorShape) == 1 && tensorShape[0] <= 0) &&
(Utils.Size(tensorShape) > 0 && tensorShape.Skip(1).All(x => x > 0)))
columnType = new VectorDataViewType(mlType, tensorShape[0] > 0 ? tensorShape : tensorShape.Skip(1).ToArray());

// There can be at most two metadata fields.
// 1. The first field always presents. Its value is this operator's type. For example,
// if an output is produced by an "Softmax" operator, the value of this field should be "Softmax".
Expand All @@ -83,7 +76,24 @@ internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph grap
(ref VBuffer<ReadOnlyMemory<char>> value) => { upstreamOperatorNames.CopyTo(ref value); });
}

schemaBuilder.AddColumn(op.name, columnType, metadataBuilder.ToAnnotations());
// Construct the final ML.NET type of a Tensorflow variable.
var tensorShape = op.output.TensorShape.dims;

if(tensorShape == null)
{
// primitive column type
schemaBuilder.AddColumn(op.name, mlType, metadataBuilder.ToAnnotations());
}
else
{
// vector column type
DataViewType columnType = new VectorDataViewType(mlType);
if (!(Utils.Size(tensorShape) == 1 && tensorShape[0] <= 0) &&
(Utils.Size(tensorShape) > 0 && tensorShape.Skip(1).All(x => x > 0)))
columnType = new VectorDataViewType(mlType, tensorShape[0] > 0 ? tensorShape : tensorShape.Skip(1).ToArray());

schemaBuilder.AddColumn(op.name, columnType, metadataBuilder.ToAnnotations());
}
}
return schemaBuilder.ToSchema();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1262,10 +1262,10 @@ class TextOutput

class PrimitiveInput
{
[LoadColumn(0, 1)]
[LoadColumn(0)]
public string input1;

[LoadColumn(1, 2)]
[LoadColumn(1)]
public string input2;
}

Expand Down Expand Up @@ -1305,8 +1305,10 @@ public void TensorFlowPrimitiveInputTest()
{
using var tensorFlowModel = _mlContext.Model.LoadTensorFlowModel(@"model_primitive_input_test");
var schema = tensorFlowModel.GetModelSchema();
Assert.True(schema.TryGetColumnIndex("input1", out var colIndex));
Assert.True(schema.TryGetColumnIndex("input2", out colIndex));
Assert.True(schema.GetColumnOrNull("input1").HasValue);
Assert.True(schema.GetColumnOrNull("input1").Value.Type is TextDataViewType);
Assert.True(schema.GetColumnOrNull("input2").HasValue);
Assert.True(schema.GetColumnOrNull("input2").Value.Type is TextDataViewType);

var dataview = _mlContext.Data.CreateTextLoader<PrimitiveInput>().Load(new MultiFileSource(null));

Expand Down

0 comments on commit 3e72d19

Please sign in to comment.