Skip to content

Commit

Permalink
Merge pull request #718 from sony/feature/20200918-Fix-FusedBatchNorm…
Browse files Browse the repository at this point in the history
…alization-for-ONNX-Exporter

Fix FusedBatchNormalization for ONNX Exporter.
  • Loading branch information
YukioOobuchi authored Sep 29, 2020
2 parents a5693a0 + 6d61642 commit 51bc881
Showing 1 changed file with 17 additions and 16 deletions.
33 changes: 17 additions & 16 deletions python/src/nnabla/utils/converter/onnx/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
TensorProto.BOOL: np.bool,
TensorProto.UINT8: np.uint8,
TensorProto.INT8: np.int8,
TensorProto.INT32: np.uint32,
TensorProto.UINT32: np.uint32,
TensorProto.INT32: np.int32,
TensorProto.INT64: np.int64,
}
Expand Down Expand Up @@ -1079,7 +1079,7 @@ def BatchNormalization(self, opset, func, func_name="BatchNormalization"):

if input_shape_reshape != input_shape:
output_y_shape = np.array(
[d for d in self._var_dict[func.output[0]].dim])
[d for d in self._var_dict[func.input[0]].dim])
n = generate_reshape(self._model_proto.graph, outputs[0], func.output[0],
output_y_shape)
nl.append(n)
Expand All @@ -1092,29 +1092,30 @@ def FusedBatchNormalization(self, opset, func):
inputs = func.input[:]
outputs = func.output[:]

if len(func.input) != 6:
raise ValueError(
"The number of FusedBatchNormalization input must be 6")

del func.input[5]
if len(func.input) > 5:
del func.input[5]
bn_out = fork_name(func.input[0]) + "_bn"
func.output[0] = bn_out
nl.extend(self.BatchNormalization(
opset, func, func_name="FusedBatchNormalization"))

# Add
add_out = fork_name(func.input[0]) + "_add"
n = onnx.helper.make_node(
'Div',
[bn_out, inputs[5]],
[add_out],
)
nl.append(n)
if len(inputs) > 5:
# Add
add_out = fork_name(func.input[0]) + "_add"
n = onnx.helper.make_node(
'Add',
[bn_out, inputs[5]],
[add_out],
)
nl.append(n)
inputs = [add_out]
else:
inputs = [bn_out]

if nonlinearity == "relu":
# Relu
n = onnx.helper.make_node("Relu",
[add_out],
inputs,
outputs)
nl.append(n)
else:
Expand Down

0 comments on commit 51bc881

Please sign in to comment.