Skip to content

Commit

Permalink
fix TensorflowUtil.GetModelSchema
Browse files Browse the repository at this point in the history
  • Loading branch information
Oceania2018 committed Aug 1, 2019
1 parent e24270a commit 04d152a
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 17 deletions.
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Dnn/Microsoft.ML.Dnn.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
<ItemGroup>
<PackageReference Include="System.IO.FileSystem.AccessControl" Version="$(SystemIOFileSystemAccessControl)" />
<PackageReference Include="System.Security.Principal.Windows" Version="$(SystemSecurityPrincipalWindows)" />
<PackageReference Include="TensorFlow.NET" Version="0.10.4" />
<PackageReference Include="TensorFlow.NET" Version="0.10.6" />
</ItemGroup>

<ItemGroup>
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Dnn/TensorflowUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ internal static Session LoadTFSession(IExceptionContext ectx, byte[] modelBytes,
var graph = new Graph();
try
{
graph.Import(modelBytes);
graph.Import(modelBytes, "");
}
catch (Exception ex)
{
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ internal static class TensorFlowUtils
internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph graph, string opType = null)
{
var schemaBuilder = new DataViewSchema.Builder();
foreach (Operation op in graph.get_operations())
foreach (var op in graph)
{
if (opType != null && opType != op.OpType)
continue;
Expand Down
31 changes: 17 additions & 14 deletions src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -461,8 +461,7 @@ private void Dispose(bool disposing)
{
if (Session != IntPtr.Zero)
{
Session.close();
Session.Dispose();
Session.close(); // invoked Dispose()
}
}
finally
Expand Down Expand Up @@ -796,39 +795,43 @@ public Tensor GetTensor()
var sbyteBuffer = (sbyte[])Convert.ChangeType(_denseData, typeof(sbyte[]));
return new Tensor(sbyteBuffer, _dims);
}

if (typeof(T) == typeof(ulong))
else if (typeof(T) == typeof(ulong))
{
var longBuffer = (ulong[])Convert.ChangeType(_denseData, typeof(ulong[]));
return new Tensor(longBuffer, _dims);
}

if (typeof(T) == typeof(UInt32))
else if (typeof(T) == typeof(UInt32))
{
var uint32Buffer = (UInt32[])Convert.ChangeType(_denseData, typeof(UInt32[]));
return new Tensor(uint32Buffer, _dims);
}

if (typeof(T) == typeof(UInt16))
else if (typeof(T) == typeof(UInt16))
{
var uint16Buffer = (UInt16[])Convert.ChangeType(_denseData, typeof(UInt16[]));
return new Tensor(uint16Buffer, _dims);
}

if (typeof(T) == typeof(bool))
else if (typeof(T) == typeof(bool))
{
return new Tensor(new NDArray(_denseData, _tfShape), TF_DataType.TF_BOOL);
}

if (typeof(T) == typeof(float))
else if (typeof(T) == typeof(float))
{
return new Tensor(new NDArray(_denseData, _tfShape), TF_DataType.TF_FLOAT);
}

if (typeof(T) == typeof(double))
else if (typeof(T) == typeof(double))
{
return new Tensor(new NDArray(_denseData, _tfShape), TF_DataType.TF_DOUBLE);
}
else if (typeof(T) == typeof(System.ReadOnlyMemory<char>))
{
byte[][] bytes = new byte[_vBuffer.Length][];
for (int i = 0; i < bytes.Length; i++)
{
bytes[i] = Encoding.UTF8.GetBytes(((System.ReadOnlyMemory<char>)(object)_denseData[i]).ToArray());
}

return new Tensor(bytes, _tfShape.dims.Select(x => (long)x).ToArray());
}

return new Tensor(new NDArray(_denseData, _tfShape)); //TFTensor.Create(_denseData, _vBuffer.Length, _tfShape);
}
Expand Down

0 comments on commit 04d152a

Please sign in to comment.