diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index c9d7b11f62dd..33e5340654cc 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1468,6 +1468,19 @@ def _impl_v11(cls, inputs, attr, params): ) +class EyeLike(OnnxOpConverter): + """Operator converter for EyeLike.""" + + @classmethod + def _impl_v9(cls, inputs, attr, params): + in_dtype = infer_type(inputs[0]).checked_type.dtype + zeros = _op.zeros_like(inputs[0]) + dim = infer_shape(zeros)[0] + indices = _op.arange(_op.const(0), _op.const(dim), dtype="int32") + ones = _op.full(_op.const(1), (dim,), dtype=in_dtype) + return _op.scatter_nd(zeros, _op.stack([indices, indices], axis=0), ones, "update") + + class Greater(OnnxOpConverter): """Operator logical greater.""" @@ -3134,6 +3147,7 @@ def _get_convert_map(opset): "NonZero": NonZero.get_converter(opset), "Range": Range.get_converter(opset), "CumSum": CumSum.get_converter(opset), + "EyeLike": EyeLike.get_converter(opset), # defs/control_flow "Loop": Loop.get_converter(opset), "If": If.get_converter(opset),