
[Relay][Op] Trilu operator implementation #12124

Merged · 7 commits · Aug 2, 2022
9 changes: 9 additions & 0 deletions include/tvm/relay/attrs/transform.h
@@ -575,6 +575,15 @@ struct StftAttrs : public tvm::AttrsNode<StftAttrs> {
  }
};  // struct StftAttrs

struct TriluAttrs : public tvm::AttrsNode<TriluAttrs> {
  bool upper;

  TVM_DECLARE_ATTRS(TriluAttrs, "relay.attrs.TriluAttrs") {
    TVM_ATTR_FIELD(upper).set_default(true).describe(
        "Whether to keep the upper or lower half of the diagonal.");
  }
};  // struct TriluAttrs

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
15 changes: 15 additions & 0 deletions python/tvm/relay/frontend/onnx.py
@@ -4685,6 +4685,20 @@ def _impl_v12(cls, inputs, attr, params):
        return _op.einsum(inputs, equation)


class Trilu(OnnxOpConverter):
    """Operator converter for Trilu"""

    @classmethod
    def _impl_v14(cls, inputs, attr, params):
        upper = attr.get("upper", True)
        if len(inputs) == 2:
            data, k = inputs
        else:
            data = inputs[0]
            k = 0
        return _op.trilu(data, k, upper)


class RandomNormal(OnnxOpConverter):
    """Operator converter for random_normal"""

@@ -5345,6 +5359,7 @@ def _get_convert_map(opset):
        "CumSum": CumSum.get_converter(opset),
        "Unique": Unique.get_converter(opset),
        "Einsum": Einsum.get_converter(opset),
        "Trilu": Trilu.get_converter(opset),
        # defs/control_flow
        "Loop": Loop.get_converter(opset),
        "If": If.get_converter(opset),
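For readers unfamiliar with the frontend flow, here is a minimal, hypothetical sketch of exercising the new converter: it builds a one-node ONNX graph with a Trilu op at opset 14 and imports it through `relay.frontend.from_onnx`. The tensor names, shapes, and llvm target are illustrative assumptions, not part of this PR.

```python
# Hypothetical usage sketch: build an ONNX Trilu node and import it through the
# new converter. Names ("x", "out") and shapes are illustrative assumptions.
import numpy as np
from onnx import helper, TensorProto

from tvm import relay

node = helper.make_node("Trilu", inputs=["x"], outputs=["out"], upper=1)
graph = helper.make_graph(
    [node],
    "trilu_example",
    inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [3, 3])],
    outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, [3, 3])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 14)])

mod, params = relay.frontend.from_onnx(model, shape={"x": (3, 3)})
result = relay.create_executor("graph", mod=mod, target="llvm").evaluate()(
    np.arange(9, dtype="float32").reshape(3, 3)
)
print(result.numpy())  # expect the lower triangle zeroed out
```

Because only one input is given here, the converter falls back to `k = 0`, matching the ONNX default.
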
35 changes: 8 additions & 27 deletions python/tvm/relay/frontend/pytorch.py
@@ -318,31 +318,6 @@ def square(self, inputs, input_types):
        (dtype,) = input_types
        return _op.power(inputs[0], _expr.const(2, dtype))

    def tril(self, inputs, input_types):
        data = inputs[0]
        if len(inputs) == 2:
            k_value = inputs[1]
        else:
            k_value = 0
        input_shape = self.infer_shape(data)
        k1, k2 = input_shape[-2:]
        k1 = k_value + 1
        diag_input = _op.zeros(input_shape, dtype=input_types[0])
        return _op.matrix_set_diag(data, diag_input, k=(k1, k2))

    def triu(self, inputs, input_types):
        data = inputs[0]
        if len(inputs) == 2:
            k_value = inputs[1]
        else:
            k_value = 0
        input_shape = self.infer_shape(data)
        k1, k2 = input_shape[-2:]
        k1 = (k1 * -1) - 1
        k2 = k_value - 1
        diag_input = _op.zeros(input_shape, dtype=input_types[0])
        return _op.matrix_set_diag(data, diag_input, k=(k1, k2))

    def lerp(self, inputs, input_types):
        if len(inputs) != 3:
            msg = "Wrong number of arguments (%d) to parse." % (len(inputs))
@@ -3405,6 +3380,12 @@ def grid_sampler(self, inputs, input_types):
            inputs[0], grid, interpolate_str, layout, padding_mode_str, align_corners
        )

    def trilu(self, inputs, input_types, mode):
        data = inputs[0]
        k = inputs[1] if inputs[1] else 0
        upper = True if mode == "triu" else False
        return _op.trilu(data, k, upper)

    # Operator mappings
    def create_convert_map(self):
        self.convert_map = {
@@ -3567,8 +3548,8 @@ def create_convert_map(self):
            "aten::sqrt": self.make_unary("sqrt"),
            "aten::rsqrt": self.make_unary("rsqrt"),
            "aten::square": self.square,
            "aten::tril": self.tril,
            "aten::triu": self.triu,
            "aten::tril": functools.partial(self.trilu, mode="tril"),
            "aten::triu": functools.partial(self.trilu, mode="triu"),
            "aten::ceil": self.make_unary("ceil"),
            "aten::floor": self.make_unary("floor"),
            "aten::round": self.make_unary("round"),
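As a quick sanity check of the new shared mapping, here is a hedged usage sketch (not part of the diff): it traces `torch.triu` with a nonzero diagonal and imports the trace through `relay.frontend.from_pytorch`. The module, input name, and shapes are assumptions for illustration.

```python
# Hypothetical usage sketch: trace torch.triu and import it with the updated
# aten::triu mapping. Input shapes and names are assumptions for illustration.
import torch
from tvm import relay

class TriuModel(torch.nn.Module):
    def forward(self, x):
        return torch.triu(x, diagonal=1)

inp = torch.rand(4, 4)
scripted = torch.jit.trace(TriuModel().eval(), inp)

mod, params = relay.frontend.from_pytorch(scripted, [("x", (4, 4))])
out = (
    relay.create_executor("graph", mod=mod, target="llvm")
    .evaluate()(inp.numpy())
    .numpy()
)
print(out)  # should match torch.triu(inp, diagonal=1)
```
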
4 changes: 4 additions & 0 deletions python/tvm/relay/op/_transform.py
@@ -191,6 +191,10 @@ def stft_shape_func(attrs, inputs, _):
    ]


# trilu
_reg.register_strategy("trilu", strategy.trilu_strategy)


# scatter_add
@_reg.register_compute("scatter_add")
def compute_scatter_add(attrs, inputs, output_type):
5 changes: 5 additions & 0 deletions python/tvm/relay/op/op_attrs.py
@@ -617,3 +617,8 @@ class NLLLossAttrs(Attrs):
@tvm._ffi.register_object("relay.attrs.FixedPointMultiplyAttrs")
class FixedPointMultiplyAttrs(Attrs):
    """Attributes used in fixed_point_multiply operators"""


@tvm._ffi.register_object("relay.attrs.TriluAttrs")
class TriluAttrs(Attrs):
    """Attributes used in trilu operators"""
28 changes: 28 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
@@ -1460,6 +1460,34 @@ def _compute_stft(attrs, inputs, output_type):
    return _compute_stft


# trilu
@override_native_generic_func("trilu_strategy")
def trilu_strategy(attrs, outs, out_type, target):
    """trilu generic strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_trilu(topi.trilu),
        wrap_topi_schedule(topi.generic.schedule_extern),
        name="trilu.generic",
    )
    return strategy


def wrap_compute_trilu(topi_compute):
    """Wrap trilu compute"""

    def _compute_trilu(attrs, inputs, output_type):
        return [
            topi_compute(
                inputs[0],
                inputs[1],
                attrs.upper,
            )
        ]

    return _compute_trilu


# roi_pool
@generic_func
def schedule_roi_pool(attrs, outs, target):
43 changes: 43 additions & 0 deletions python/tvm/relay/op/transform.py
@@ -1889,3 +1889,46 @@ def stft(
        window = _make.ones([n_fft], "int32")

    return _make.stft(data, n_fft, hop_length, win_length, window, normalized, onesided)


def trilu(data, k, upper=True):
    """
    Given a 2-D matrix or batches of 2-D matrices, returns the
    upper or lower triangular part of the tensor.

    Parameters
    ----------
    data : relay.Expr
        The tensor that trilu will be applied to. Must be either
        a 2D matrix or a tensor of batches of 2D matrices.

    k : int
        The number of diagonals above or below the main diagonal
        to exclude or include.

    upper : bool, optional
        If True, only the upper triangular values of the input are kept;
        if False, only the lower triangular values are kept.

    Returns
    -------
    ret : relay.Expr
        The new tensor with appropriate diagonals set to zero.

    Examples
    --------
    .. code-block:: python

        x = [[0, 1, 2],
             [3, 4, 5],
             [6, 7, 8]]

        relay.trilu(x, k=0, upper=True) =
            [[0, 1, 2],
             [0, 4, 5],
             [0, 0, 8]]
    """
    if not isinstance(k, Expr):
        k = const(k, dtype="int32")
    return _make.trilu(data, k, upper)
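To make the docstring example concrete, here is a small runnable sketch (llvm target and graph executor assumed) that compares `relay.trilu` against the equivalent NumPy call:

```python
# Minimal sketch of the Relay API: keep the upper triangle (k=0) and compare
# against np.triu. Target/device choices are assumptions for illustration.
import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(3, 3), dtype="float32")
mod = tvm.ir.IRModule.from_expr(relay.trilu(x, k=0, upper=True))

x_np = np.arange(9, dtype="float32").reshape(3, 3)
res = relay.create_executor("graph", mod=mod, target="llvm").evaluate()(x_np).numpy()
np.testing.assert_allclose(res, np.triu(x_np, 0))
```
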
58 changes: 58 additions & 0 deletions python/tvm/topi/transform.py
@@ -1001,3 +1001,61 @@ def sliding_window(data, axis, window_shape, strides):
        The resulting tensor.
    """
    return cpp.sliding_window(data, axis, window_shape, strides)


def trilu(data, k, upper):
    """
    Given a 2-D matrix or batches of 2-D matrices, returns the
    upper or lower triangular part of the tensor.

    Parameters
    ----------
    data : tvm.te.Tensor
        The tensor that trilu will be applied to. Must be either
        a 2D matrix or a tensor of batches of 2D matrices.

    k : tvm.te.Tensor
        The number of diagonals above or below the main diagonal
        to exclude or include.

    upper : bool
        If True, only the upper triangular values of the input are kept;
        if False, only the lower triangular values are kept.

    Returns
    -------
    ret : tvm.te.Tensor
        The new tensor with appropriate diagonals set to zero.

    Examples
    --------
    .. code-block:: python

        x = [[0, 1, 2],
             [3, 4, 5],
             [6, 7, 8]]

        topi.trilu(x, 0, True) =
            [[0, 1, 2],
             [0, 4, 5],
             [0, 0, 8]]
    """
    # Make sure datatype is consistent.
    if k.dtype != "int32":
        k = tvm.tir.Cast("int32", k)

    # Check either above or below the diagonal depending on upper.
    check_op = tvm.tir.GE
    if upper:
        check_op = tvm.tir.LE

    def _apply_trilu(*indices):
        row_index = indices[-2]
        col_index = indices[-1]
        other_indices = indices[:-2]
        check_position = check_op(row_index, col_index - k)
        value = data(*other_indices, row_index, col_index)
        return tvm.tir.Select(check_position, value, tvm.tir.const(0, data.dtype))

    return te.compute(data.shape, _apply_trilu, name="trilu")
50 changes: 50 additions & 0 deletions src/relay/op/tensor/transform.cc
@@ -4230,5 +4230,55 @@ RELAY_REGISTER_OP("invert_permutation")
    .set_attr<TOpPattern>("TOpPattern", kInjective)
    .set_attr<TOpIsStateful>("TOpIsStateful", false);

// Trilu

TVM_REGISTER_NODE_TYPE(TriluAttrs);

bool TriluRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
              const TypeReporter& reporter) {
  // types: [data, k, result]
  ICHECK_EQ(types.size(), 3) << "Trilu: expect 3 types but " << types.size() << " provided";
  ICHECK_EQ(num_inputs, 2) << "Trilu: expect 2 inputs but " << num_inputs << " provided";
  auto data = types[0].as<TensorTypeNode>();
  if (data == nullptr) {
    ICHECK(types[0].as<IncompleteTypeNode>())
        << "Trilu: expect input type to be TensorType but get " << types[0];
    return false;
  }

  auto k = types[1].as<TensorTypeNode>();
  if (k == nullptr) {
    ICHECK(types[1].as<IncompleteTypeNode>())
        << "Trilu: expect k type to be TensorType but get " << types[1];
    return false;
  }

  ICHECK(k->shape.size() == 0) << "Trilu: k must be a 0-D tensor but get " << k;

  // Output shape is the same as input shape.
  reporter->Assign(types[2], TensorType(data->shape, data->dtype));
  return true;
}

Expr MakeTrilu(Expr data, Expr k, bool upper) {
  auto attrs = make_object<TriluAttrs>();
  attrs->upper = upper;
  static const Op& op = Op::Get("trilu");
  return Call(op, {data, k}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.trilu").set_body_typed(MakeTrilu);

RELAY_REGISTER_OP("trilu")
    .describe(
        R"code(Filters out the upper or lower portion of an input tensor on one side of a diagonal.
)code" TVM_ADD_FILELINE)
    .set_num_inputs(2)
    .add_argument("data", "Tensor", "The input tensor")
    .add_argument("k", "Tensor", "The number of diagonals above or below the main to exclude.")
    .add_type_rel("trilu", TriluRel)
    .set_support_level(3)
    .set_attr<TOpPattern>("TOpPattern", kElemWise);
} // namespace relay
} // namespace tvm
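A small illustrative check (assumed, not part of the PR) of what `TriluRel` does: after type inference, the output type mirrors the input's shape and dtype, while `k` must be a 0-D int32 tensor (here produced by the Python wrapper's `const(k, "int32")`):

```python
# Sketch: the TriluRel type relation assigns the input's shape/dtype to the
# output; InferType makes that visible. Shapes and dtype are illustrative.
import tvm
from tvm import relay

x = relay.var("x", shape=(2, 5, 5), dtype="float16")
mod = tvm.ir.IRModule.from_expr(relay.trilu(x, k=-1, upper=False))
mod = relay.transform.InferType()(mod)
print(mod["main"].ret_type)  # expect Tensor[(2, 5, 5), float16]
```
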
16 changes: 0 additions & 16 deletions tests/python/frontend/onnx/test_forward.py
@@ -5242,23 +5242,7 @@ def verify_eyelike(indata, dynamic=False):
"test_training_dropout_mask",
"test_training_dropout_zero_ratio",
"test_training_dropout_zero_ratio_mask",
"test_tril",
"test_tril_pos",
"test_tril_square",
"test_tril_square_neg",
"test_tril_neg",
"test_tril_one_row_neg",
"test_tril_out_neg",
"test_tril_out_pos",
"test_tril_zero",
Review thread on the removed tril/triu test skips:

Contributor (reviewer): I haven't looked at this op at all. How tricky would it be to support the zero case? Otherwise LGTM.

Contributor Author (jwfromm): It actually works on llvm and cuda. I was testing on my MacBook and it seems like the metal backend in general doesn't support empty tensors. I think for CI we could add these cases.

Contributor Author (jwfromm, Jul 26, 2022): Seems like it also doesn't work with nvptx, for the same issue with empty tensors. I'll add them here and see how it does in CI.

"test_triu",
"test_triu_one_row",
"test_triu_out_neg_out",
"test_triu_out_pos",
"test_triu_neg",
"test_triu_pos",
"test_triu_square",
"test_triu_square_neg",
"test_triu_zero",
"test_unique_sorted_with_axis",
"test_unique_sorted_with_axis_3d",
Expand Down
10 changes: 10 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
@@ -4616,5 +4616,15 @@ def test_fn(x, y, w):
    verify_model(test_fn, [x, y, w[0]])


def test_trilu():
    def _test_trilu(op, diagonal):
        return lambda inp: op(inp, diagonal)

    for op in [torch.triu, torch.tril]:
        verify_model(_test_trilu(op, 0), [torch.rand(size=[3, 3])])
        verify_model(_test_trilu(op, 1), [torch.rand(size=[6, 6])])
        verify_model(_test_trilu(op, -2), [torch.rand(size=[6, 6])])


if __name__ == "__main__":
pytest.main([__file__])
29 changes: 29 additions & 0 deletions tests/python/relay/test_op_level3.py
@@ -2207,5 +2207,34 @@ def test_stft(
    )


def test_trilu(target="llvm", dev=tvm.cpu()):
def verify_trilu(data_shape, upper=True, k=0):
data = relay.var("data", relay.TensorType(data_shape, "float32"))
y = relay.trilu(data, k, upper)
mod = tvm.ir.IRModule.from_expr(y)

data_np = np.random.normal(size=data_shape).astype("float32")
tvm_res = (
relay.create_executor("graph", mod=mod, device=dev, target=target)
.evaluate()(data_np)
.numpy()
)
if upper:
np_res = np.triu(data_np, k)
else:
np_res = np.tril(data_np, k)
tvm.testing.assert_allclose(tvm_res, np_res)

# Test upper and lower triangle
verify_trilu((3, 3), True, 0)
verify_trilu((3, 3), False, 0)
# Test larger matrices with offset.
verify_trilu((6, 6), True, 1)
verify_trilu((6, 6), False, 2)
verify_trilu((6, 6), False, -2)
# Test batch size
verify_trilu((8, 6, 6), False, -2)


if __name__ == "__main__":
tvm.testing.main()