Skip to content

Commit

Permalink
[ONNX] Add OpSet 13 implementation for Hardmax (apache#8924)
Browse files Browse the repository at this point in the history
* Add opset 13 impl for hardmax

* Format
  • Loading branch information
michalpiszczek authored and ylc committed Sep 29, 2021
1 parent b9099d8 commit 05ed659
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
17 changes: 17 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1899,6 +1899,23 @@ def _impl_v1(cls, inputs, attr, params):
)
return _op.reshape(onehot, shape_of(inputs[0]))

@classmethod
def _impl_v13(cls, inputs, attr, params) -> relay.Expr:
inferred_type = infer_type(inputs[0])
dtype = inferred_type.checked_type.dtype
ndim = len(inferred_type.checked_type.shape)
axis = attr.get("axis", -1) % ndim

argmax = _op.argmax(inputs[0], axis=axis)
return _op.one_hot(
argmax,
_op.const(1.0, dtype),
_op.const(0.0, dtype),
fold_constant(_op.take(shape_of(inputs[0]), _op.const([axis], "int64"))),
axis,
dtype,
)


class OneHot(OnnxOpConverter):
"""Operator converter for OneHot."""
Expand Down
3 changes: 0 additions & 3 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4722,9 +4722,6 @@ def verify_eyelike(indata):
"test_einsum_transpose",
"test_greater_equal",
"test_greater_equal_bcast",
"test_hardmax_axis_0",
"test_hardmax_axis_1",
"test_hardmax_default_axis",
"test_if_seq",
"test_less_equal",
"test_less_equal_bcast",
Expand Down

0 comments on commit 05ed659

Please sign in to comment.