[Relay/TOPI][ONNX/TFLite] Refactor MATRIX_SET_DIAG Operator for Relay… #9329

Closed · wants to merge 8 commits
Changes from 2 commits
4 changes: 0 additions & 4 deletions include/tvm/relay/attrs/transform.h
@@ -456,14 +456,10 @@ struct OneHotAttrs : public tvm::AttrsNode<OneHotAttrs> {

/*! \brief Attributes used in matrix_set_diag operator */
struct MatrixSetDiagAttrs : public tvm::AttrsNode<MatrixSetDiagAttrs> {
-  int k1;
-  int k2;
  bool super_diag_right_align;
  bool sub_diag_right_align;

  TVM_DECLARE_ATTRS(MatrixSetDiagAttrs, "relay.attrs.MatrixSetDiagAttrs") {
-    TVM_ATTR_FIELD(k1).set_default(0).describe("Lower limit (included) of the range of diagonals.");
-    TVM_ATTR_FIELD(k2).set_default(0).describe("Upper limit (included) of the range of diagonals.");
    TVM_ATTR_FIELD(super_diag_right_align)
        .set_default(true)
        .describe("Bool, true iff super-diagonal is right aligned (left-padded).");
22 changes: 11 additions & 11 deletions include/tvm/topi/transform.h
@@ -1759,14 +1759,13 @@ inline Tensor sparse_to_dense(const Tensor& sparse_indices, const Array<PrimExpr
 * \param tag output tensor tag.
 * \return new tensor with given diagonal values.
 */
-inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k1, int k2,
-                              bool super_diag_right_align, bool sub_diag_right_align,
+inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, const Tensor& k1,
+                              const Tensor& k2, bool super_diag_right_align,
+                              bool sub_diag_right_align,
                              const std::string name = "T_matrix_set_diag",
                              const std::string tag = kInjective) {
  size_t ndim = input->shape.size() - 1;

-  bool only_one_diagonal = k1 == k2;
-
  return compute(
      input->shape,
      [&](const Array<Var>& iter_vars) {
@@ -1776,12 +1775,10 @@ inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k
        for (size_t i = 0; i < ndim - 1; i++) {
          diagonal_indices.push_back(iter_vars[i]);
        }
-        if (only_one_diagonal) {
-          k = k1;
-        } else {
+        auto multi_diagonals = [&]() {
          // Determining which diagonal/sub-diagonal/super-diagonal it is
          k = iter_vars[ndim] - iter_vars[ndim - 1];
-          diagonal_indices.push_back(k2 - k);
+          diagonal_indices.push_back(k2(0) - k);
[Inline review thread]
Contributor: I believe (though please check) that you can just do k2() for 0-D tensor access.
Contributor Author: ok
Contributor Author: k2() will report error
Contributor: Eh ok
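On the k2(0) vs k2() question above: the k limits reach TOPI as rank-1, length-1 tensors (the tests below build them as te.placeholder(shape=(1,), ...)), not as true 0-D scalars, which is presumably why the no-argument call fails. A minimal TE sketch of the same access pattern, written in Python and not part of this diff (names and shapes are illustrative only):

```python
import tvm
from tvm import te, tir

n = te.var("n")
x = te.placeholder((n, n), name="x", dtype="float32")
k = te.placeholder((1,), name="k", dtype="int64")  # rank-1, length-1: not 0-D

# k[0] (the Python spelling of C++ k(0)) pulls the scalar PrimExpr out of the
# length-1 tensor; keep x[i, j] on or above diagonal k, zero it below.
y = te.compute(
    (n, n),
    lambda i, j: tir.if_then_else(
        (j - i).astype("int64") >= k[0], x[i, j], tir.const(0.0, "float32")
    ),
    name="y",
)
```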


          // Calculating the offset in diagonal tensor for this diagonal
          auto get_offset = [&](PrimExpr M, PrimExpr N) {
@@ -1794,13 +1791,16 @@ inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k
                  : 0,
              sub_diag_right_align ? get_offset(input->shape[ndim], input->shape[ndim - 1] + k)
                  : 0);
-        }
+          return k;
+        };
+        auto get_k = [&]() { return if_then_else(k1(0) == k2(0), k1(0), multi_diagonals()); };
+        k = get_k();
        diagonal_indices.push_back(if_then_else(k >= 0, iter_vars[ndim - 1], iter_vars[ndim]) +
                                   offset);
        return diagonal(diagonal_indices);
      };
-      return if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] >= k1,
-                          if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] <= k2,
+      return if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] >= k1(0),
+                          if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] <= k2(0),
                                       get_diag(), input(iter_vars)),
                          input(iter_vars));
    },
30 changes: 30 additions & 0 deletions python/tvm/relay/frontend/onnx.py
@@ -4226,6 +4226,35 @@ def _impl_v1(cls, inputs, attr, params):
        return _expr.TupleWrapper(_expr.Tuple(result), len(result))


+class Trilu(OnnxOpConverter):
+    """Operator converter for Trilu"""
+
+    @classmethod
+    def _impl_v14(cls, inputs, attr, params):
+        upper = attr.get("upper", 1)
+        input_shape = shape_of(inputs[0])
+        input_dims = infer_shape(input_shape)[0]
+        data_type = infer_type(inputs[0]).checked_type.dtype
+        k_tensor = relay.const(np.asarray([0], dtype=np.int64))
+        if len(inputs) == 2:
+            k_tensor = inputs[1]
+
+        diag_input = relay.zeros(fold_constant(shape_of(inputs[0])), dtype=data_type)
+
+        if upper == 0:
+            k1 = relay.add(k_tensor, relay.const(1, dtype="int64"))
+            k2 = relay.take(input_shape, relay.const(input_dims - 1, dtype="int32"))
+            k2 = relay.expand_dims(k2, axis=0)
+            return relay.matrix_set_diag(inputs[0], diag_input, k=(k1, k2))
+        else:
+            k1 = relay.take(input_shape, relay.const(input_dims - 2, dtype="int32"))
+            k1 = relay.multiply(k1, relay.const(-1, dtype="int64"))
+            k1 = relay.subtract(k1, relay.const(1, dtype="int64"))
+            k1 = relay.expand_dims(k1, axis=0)
+            k2 = relay.subtract(k_tensor, relay.const(1, dtype="int64"))
+            return relay.matrix_set_diag(inputs[0], diag_input, k=(k1, k2))


# compatible operators that do NOT require any conversion.
_identity_list = []

@@ -4425,6 +4454,7 @@ def _get_convert_map(opset):
        "Adagrad": Adagrad.get_converter(opset),
        "Adam": Adam.get_converter(opset),
        "Momentum": Momentum.get_converter(opset),
+        "Trilu": Trilu.get_converter(opset),
    }


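The Trilu conversion above leans on a simple identity between triangular masking and zeroing a band of diagonals. A rough numpy illustration, separate from the converter code (the helper name here is ours):

```python
import numpy as np

def zero_diag_range(x, k1, k2):
    """Zero every diagonal d with k1 <= d <= k2, where d = column - row."""
    M, N = x.shape
    out = x.copy()
    for d in range(k1, k2 + 1):
        i = np.arange(max(0, -d), min(M, N - d))
        out[i, i + d] = 0.0
    return out

x = np.arange(20, dtype=np.float32).reshape(4, 5)
k = 1
M, N = x.shape

# triu keeps diagonals >= k, so zero the range [-(M-1), k-1].
assert np.array_equal(zero_diag_range(x, -(M - 1), k - 1), np.triu(x, k))
# tril keeps diagonals <= k, so zero the range [k+1, N-1].
assert np.array_equal(zero_diag_range(x, k + 1, N - 1), np.tril(x, k))
```

The converter itself passes slightly wider limits (e.g. N instead of N - 1, and -M - 1 instead of -(M - 1)), which the refactored operator appears to tolerate.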
14 changes: 11 additions & 3 deletions python/tvm/relay/frontend/tflite.py
@@ -3293,6 +3293,11 @@ def convert_matrix_set_diag(self, op):

        input_expr = self.get_tensor_expr(input_tensors[0])
        diagonal_expr = self.get_tensor_expr(input_tensors[1])
+        diag_shape = to_int_list(self.get_tensor_shape(input_tensors[1]))
+        input_shape = to_int_list(self.get_tensor_shape(input_tensors[0]))
+        if len(diag_shape) == len(input_shape) - 1:
+            diag_shape = np.insert(diag_shape, len(diag_shape) - 1, 1)
+            diagonal_expr = _op.reshape(diagonal_expr, diag_shape)

        out = _op.matrix_set_diag(input_expr, diagonal_expr)
        return out
@@ -3313,13 +3318,16 @@ def convert_matrix_diag(self, op):
            scale and zero points to be equal"

        shape = to_int_list(self.get_tensor_shape(diagonal))
-        shape = np.append(shape, shape[-1])
+        shape_copy = np.copy(shape)
+        diag_shape = np.insert(shape, len(shape) - 1, 1).astype(np.int32)
+
+        shape = np.append(shape_copy, shape[-1]).astype(np.int32)
        dtype = self.get_tensor_type_str(diagonal.tensor.Type())

        input_expr = _op.zeros(tuple(shape), dtype)
        diagonal_expr = self.get_tensor_expr(diagonal)

-        out = _op.matrix_set_diag(input_expr, diagonal_expr)
+        out = _op.matrix_set_diag(input_expr, _op.reshape(diagonal_expr, diag_shape))
        return out

    def convert_densify(self, op):
6 changes: 6 additions & 0 deletions python/tvm/relay/op/transform.py
@@ -23,6 +23,7 @@
from . import _make
from .dyn import _make as _dyn_make
from .tensor import shape_of
+import numpy as np


def cast(data, dtype):
@@ -1352,6 +1353,11 @@ def matrix_set_diag(data, diagonal, k=0, align="RIGHT_LEFT"):
        k_one = k
        k_two = k

+    if not isinstance(k_one, Expr):
+        k_one = const(np.asarray([k_one], dtype=np.int64))
+    if not isinstance(k_two, Expr):
+        k_two = const(np.asarray([k_two], dtype=np.int64))
+
    super_diag_right_align = align[:5] == "RIGHT"
    sub_diag_right_align = align[-5:] == "RIGHT"

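A hypothetical usage sketch of the updated wrapper, assuming the single-diagonal convention used elsewhere in this PR (a unit dimension inserted before the last axis of the diagonal tensor):

```python
import numpy as np
import tvm
from tvm import relay

data = relay.var("data", shape=(2, 3, 3), dtype="float32")
diag = relay.var("diag", shape=(2, 1, 3), dtype="float32")

# Scalar k: wrapped by matrix_set_diag into 1-element int64 constants.
out_scalar = relay.matrix_set_diag(data, diag, k=0)

# Expr k: passed through unchanged; both limits supplied explicitly.
k = relay.const(np.asarray([0], dtype=np.int64))
out_expr = relay.matrix_set_diag(data, diag, k=(k, k))
```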
38 changes: 13 additions & 25 deletions src/relay/op/tensor/transform.cc
@@ -3708,59 +3708,45 @@ TVM_REGISTER_NODE_TYPE(MatrixSetDiagAttrs);
bool MatrixSetDiagRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                      const TypeReporter& reporter) {
-  // `types` contains: [input, diagonal, result]
-  ICHECK_EQ(types.size(), 3);
+  // `types` contains: [input, diagonal, k1, k2, result]
+  ICHECK_EQ(types.size(), 5);

  const auto* input = types[0].as<TensorTypeNode>();
  ICHECK(input);

  const auto* diagonal = types[1].as<TensorTypeNode>();
  ICHECK(diagonal);

-  const auto param = attrs.as<MatrixSetDiagAttrs>();
-  ICHECK_GE(param->k2, param->k1);
-
-  int d_ndims = diagonal->shape.size();
-  int i_ndims = input->shape.size();
+  const auto* k1 = types[2].as<TensorTypeNode>();
+  ICHECK(k1);

-  reporter->Assert(input->shape[i_ndims - 2] > -param->k1);
-  reporter->Assert(input->shape[i_ndims - 1] > param->k2);
+  const auto* k2 = types[3].as<TensorTypeNode>();
+  ICHECK(k2);

+  int d_ndims = diagonal->shape.size();
  for (int i = 0; i < d_ndims - 2; i++) {
    reporter->AssertEQ(input->shape[i], diagonal->shape[i]);
  }
-  if (param->k1 != param->k2) {
-    reporter->AssertEQ(diagonal->shape[d_ndims - 2], param->k2 - param->k1 + 1);
-  } else if (d_ndims >= 2) {
-    reporter->AssertEQ(input->shape[d_ndims - 2], diagonal->shape[d_ndims - 2]);
-  }
-  auto max_diag_len = if_then_else(input->shape[i_ndims - 2] + (param->k2 > 0 ? param->k2 : 0) <=
-                                       input->shape[i_ndims - 1] + (param->k1 < 0 ? -param->k1 : 0),
-                                   input->shape[i_ndims - 2] + (param->k2 > 0 ? param->k2 : 0),
-                                   input->shape[i_ndims - 1] + (param->k1 < 0 ? -param->k1 : 0));
-  reporter->AssertEQ(diagonal->shape[d_ndims - 1], max_diag_len);

-  reporter->Assign(types[2], TensorType(input->shape, input->dtype));
+  reporter->Assign(types[4], TensorType(input->shape, input->dtype));
  return true;
}

Array<te::Tensor> MatrixSetDiagCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                       const Type& out_type) {
  const auto* param = attrs.as<MatrixSetDiagAttrs>();
  ICHECK(param != nullptr);
-  return Array<te::Tensor>{topi::matrix_set_diag(inputs[0], inputs[1], param->k1, param->k2,
+  return Array<te::Tensor>{topi::matrix_set_diag(inputs[0], inputs[1], inputs[2], inputs[3],
                                                 param->super_diag_right_align,
                                                 param->sub_diag_right_align)};
}

-Expr MakeMatrixSetDiag(Expr input, Expr diagonal, int k1, int k2, bool super_diag_right_align,
+Expr MakeMatrixSetDiag(Expr input, Expr diagonal, Expr k1, Expr k2, bool super_diag_right_align,
                       bool sub_diag_right_align) {
  auto attrs = make_object<MatrixSetDiagAttrs>();
-  attrs->k1 = k1;
-  attrs->k2 = k2;
  attrs->super_diag_right_align = super_diag_right_align;
  attrs->sub_diag_right_align = sub_diag_right_align;
  static const Op& op = Op::Get("matrix_set_diag");
-  return Call(op, {input, diagonal}, Attrs(attrs), {});
+  return Call(op, {input, diagonal, k1, k2}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.matrix_set_diag").set_body_typed(MakeMatrixSetDiag);
@@ -3776,9 +3762,11 @@ RELAY_REGISTER_OP("matrix_set_diag")
**sub_diag_right_align** Bool, true iff sub-diagonal is right aligned (left-padded).
)code" TVM_ADD_FILELINE)
    .set_attrs_type<MatrixSetDiagAttrs>()
-    .set_num_inputs(2)
+    .set_num_inputs(4)
    .add_argument("input", "Tensor", "Input Tensor.")
    .add_argument("diagonal", "Tensor", "Values to be filled in the diagonal.")
+    .add_argument("k1", "Tensor", "Lower limit (included) of the range of diagonals.")
+    .add_argument("k2", "Tensor", "Upper limit (included) of the range of diagonals.")
    .set_support_level(10)
    .add_type_rel("MatrixSetDiag", MatrixSetDiagRel)
    .set_attr<FTVMCompute>("FTVMCompute", MatrixSetDiagCompute)
5 changes: 2 additions & 3 deletions src/topi/transform.cc
@@ -209,11 +209,10 @@ TVM_REGISTER_GLOBAL("topi.one_hot").set_body([](TVMArgs args, TVMRetValue* rv) {
});

TVM_REGISTER_GLOBAL("topi.matrix_set_diag").set_body([](TVMArgs args, TVMRetValue* rv) {
-  int k1 = args[2];
-  int k2 = args[3];
  bool super_diag_right_align = args[4];
  bool sub_diag_right_align = args[5];
-  *rv = matrix_set_diag(args[0], args[1], k1, k2, super_diag_right_align, sub_diag_right_align);
+  *rv = matrix_set_diag(args[0], args[1], args[2], args[3], super_diag_right_align,
+                        sub_diag_right_align);
});

TVM_REGISTER_GLOBAL("topi.adv_index").set_body([](TVMArgs args, TVMRetValue* rv) {
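A small sketch, not from this PR, of driving the registered packed function directly from Python under the new signature; argument positions mirror the TVMArgs layout above (tensors at 0-3, alignment booleans at 4-5):

```python
import tvm
from tvm import te

x = te.placeholder((3, 3), name="x", dtype="float32")
d = te.placeholder((1, 3), name="d", dtype="float32")
k1 = te.placeholder((1,), name="k1", dtype="int64")
k2 = te.placeholder((1,), name="k2", dtype="int64")

f = tvm.get_global_func("topi.matrix_set_diag")
# True/False = super_diag_right_align / sub_diag_right_align ("RIGHT_LEFT").
out = f(x, d, k1, k2, True, False)  # returns a te.Tensor
```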
42 changes: 42 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
@@ -5773,6 +5773,47 @@ def repeat(N, D):
        repeat(2, D),
    )

+
+@tvm.testing.parametrize_targets
+def test_trilu(target, dev):
+    def verify_trilu(in_shape, k, upper):
+        trilu_node = helper.make_node("Trilu", inputs=["x", "k"], outputs=["out"], upper=upper)
+        graph = helper.make_graph(
+            [trilu_node],
+            "trilu_test",
+            inputs=[
+                helper.make_tensor_value_info("x", TensorProto.FLOAT, list(in_shape)),
+                helper.make_tensor_value_info("k", TensorProto.INT64, list((1,))),
+            ],
+            outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(in_shape))],
+        )
+
+        model = helper.make_model(graph, producer_name="trilu_test")
+        input_array = np.random.rand(*in_shape).astype("float32")
+        verify_with_ort_with_inputs(
+            model, [input_array, np.asarray(k, dtype="int64")], target=target, dev=dev, use_vm=True
+        )
+
+    in_shape = (4, 5)
+    verify_trilu(in_shape, [4], 0)
+    verify_trilu(in_shape, [5], 0)
+    verify_trilu(in_shape, [5], 1)
+    verify_trilu(in_shape, [-1], 0)
+    verify_trilu(in_shape, [-1], 1)
+    verify_trilu(in_shape, [-7], 0)
+    verify_trilu(in_shape, [-7], 1)
+    verify_trilu(in_shape, [6], 0)
+    verify_trilu(in_shape, [6], 1)
+
+    in_shape = (3, 1, 5)
+    verify_trilu(in_shape, [0], 0)
+    verify_trilu(in_shape, [1], 1)
+    verify_trilu(in_shape, [6], 0)
+    verify_trilu(in_shape, [6], 1)
+
+    in_shape = (3, 5, 5)
+    verify_trilu(in_shape, [0], 0)
+    verify_trilu(in_shape, [0], 1)
+    verify_trilu(in_shape, [-1], 0)
+    verify_trilu(in_shape, [-1], 1)


if __name__ == "__main__":
    test_flatten()
@@ -5864,3 +5905,4 @@ def repeat(N, D):
    test_convinteger()
    test_batch_matmul()
    test_global_lppool()
+    test_trilu()
8 changes: 7 additions & 1 deletion tests/python/relay/test_op_level10.py
@@ -543,7 +543,13 @@ def test_matrix_set_diag():
    def _verify(input_shape, diagonal_shape, dtype, k=0, align="RIGHT_LEFT"):
        input = relay.var("input", relay.TensorType(input_shape, dtype))
        diagonal = relay.var("diagonal", relay.TensorType(diagonal_shape, dtype))
-        out = relay.matrix_set_diag(input, diagonal, k, align)
+        out = None
+        if len(diagonal_shape) == len(input_shape) - 1:
+            new_shape = list(diagonal_shape)
+            new_shape.insert(-1, 1)
+            out = relay.matrix_set_diag(input, relay.reshape(diagonal, new_shape), k, align)
+        else:
+            out = relay.matrix_set_diag(input, diagonal, k, align)

        in_type = run_infer_type(input)
        out_type = run_infer_type(out)
24 changes: 18 additions & 6 deletions tests/python/topi/python/test_topi_transform.py
@@ -753,21 +753,36 @@ def check_device(target, dev):
def verify_matrix_set_diag(input_shape, diagonal_shape, dtype, k=0, align="RIGHT_LEFT"):
    input = te.placeholder(shape=input_shape, name="input", dtype=dtype)
    diagonal = te.placeholder(shape=diagonal_shape, name="diagonal", dtype=dtype)
-    matrix_set_diag_result = topi.transform.matrix_set_diag(input, diagonal, k, align)
+    k1 = te.placeholder(shape=(1,), name="k1", dtype="int64")
+    k2 = te.placeholder(shape=(1,), name="k2", dtype="int64")
+    matrix_set_diag_result = topi.transform.matrix_set_diag(input, diagonal, (k1, k2), align)
+
+    k_one, k_two = None, None
+    if isinstance(k, (tuple, list)):
+        k_one = k[0]
+        if len(k) >= 2:
+            k_two = k[1]
+        else:
+            k_two = k[0]
+    else:
+        k_one = k
+        k_two = k

    def check_device(target, dev):
        dev = tvm.device(target, 0)
        print("Running on target: %s" % target)
        with tvm.target.Target(target):
            s = tvm.topi.testing.get_injective_schedule(target)(matrix_set_diag_result)
-        fn = tvm.build(s, [input, diagonal, matrix_set_diag_result], target, name="matrix_set_diag")
+        fn = tvm.build(
+            s, [input, diagonal, k1, k2, matrix_set_diag_result], target, name="matrix_set_diag"
+        )
        input_npy = np.random.randint(-100, 100, size=input_shape).astype(dtype)
        diagonal_npy = np.random.randint(-100, 100, size=diagonal_shape).astype(dtype)
        out_npy = tvm.topi.testing.matrix_set_diag(input_npy, diagonal_npy, k, align)
        input_nd = tvm.nd.array(input_npy, dev)
        diagonal_nd = tvm.nd.array(diagonal_npy, dev)
+        k1_nd = tvm.nd.array(np.asarray([k_one], dtype="int64"), dev)
+        k2_nd = tvm.nd.array(np.asarray([k_two], dtype="int64"), dev)
        out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(matrix_set_diag_result.dtype), dev)
-        fn(input_nd, diagonal_nd, out_nd)
+        fn(input_nd, diagonal_nd, k1_nd, k2_nd, out_nd)
        out_topi = out_nd.numpy()
        tvm.testing.assert_allclose(out_topi, out_npy)

@@ -1235,9 +1250,6 @@ def test_sparse_to_dense():
@tvm.testing.uses_gpu
def test_matrix_set_diag():
    for dtype in ["float32", "int32"]:
-        verify_matrix_set_diag((2, 2), (2,), dtype)
-        verify_matrix_set_diag((4, 3, 3), (4, 3), dtype)
-        verify_matrix_set_diag((2, 3, 4), (2, 3), dtype, 1)
        verify_matrix_set_diag((2, 3, 4), (2, 4, 3), dtype, (-1, 2), "LEFT_RIGHT")
        verify_matrix_set_diag((2, 3, 4), (2, 4, 3), dtype, (-1, 2), "LEFT_LEFT")
        verify_matrix_set_diag((2, 3, 4), (2, 4, 3), dtype, (-1, 2), "RIGHT_RIGHT")