Skip to content

Commit

Permalink
backport tf2bc changes from barracuda-release (#3341)
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris Elion authored Feb 5, 2020
1 parent b3755a5 commit 5a1bab7
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions ml-agents/mlagents/trainers/tensorflow_to_barracuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,9 +592,12 @@ def get_attr(node, attr_name, default=None):
val = node.attr[attr_name]

if val.HasField("list"):
return val.list.i
# NOTE: can't find way to identify type of list BUT it is almost always list(int)
# except list(float) in FractionalAvg/MaxPool
if len(val.list.shape) > 0:
return val.list.shape
else:
return val.list.i
if val.HasField("b"):
return val.b
if val.HasField("i"):
Expand All @@ -618,6 +621,10 @@ def get_epsilon(layer):

def get_layer_rank(layer):
shape = get_attr(layer, "shape")
if not shape:
outputShapes = get_attr(layer, "_output_shapes")
if outputShapes:
shape = outputShapes[0]
if not shape:
return None
if isinstance(shape, list):
Expand Down Expand Up @@ -753,7 +760,7 @@ def axis_to_barracuda(axis, input_rank):
W = 2
C = 3
if axis < 0:
axis = input_rank - axis
axis = input_rank + axis
assert axis >= 0
assert axis < input_rank
if input_rank == 4:
Expand Down

0 comments on commit 5a1bab7

Please sign in to comment.