
[Relay][Op] Trilu operator implementation (apache#12124)
* Added topi trilu implementation.

* Implemented and tested full Trilu op.

* Fix test type.

* Add tril zero tests.

* Add pytorch trilu integration.

* Clean up torch integration.

* Re-added skip for zero tests.
Josh Fromm authored and xinetzone committed Nov 25, 2022
1 parent fe4b8c8 commit d2126de
Showing 13 changed files with 298 additions and 43 deletions.
9 changes: 9 additions & 0 deletions include/tvm/relay/attrs/transform.h
@@ -575,6 +575,15 @@ struct StftAttrs : public tvm::AttrsNode<StftAttrs> {
  }
};  // struct StftAttrs

struct TriluAttrs : public tvm::AttrsNode<TriluAttrs> {
  bool upper;

  TVM_DECLARE_ATTRS(TriluAttrs, "relay.attrs.TriluAttrs") {
    TVM_ATTR_FIELD(upper).set_default(true).describe(
        "Whether to keep the upper or lower half of the diagonal.");
  }
};  // struct TriluAttrs

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
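Note: the struct above only carries the `upper` flag; rank and dtype checks live in the type relation added to transform.cc below. As a small illustrative sketch (not part of this commit), the attribute becomes visible from Python once the op is registered:

import tvm
from tvm import relay

x = relay.var("x", shape=(3, 3), dtype="float32")
call = relay.trilu(x, 0, upper=False)  # Call node carrying TriluAttrs
print(call.attrs.upper)  # the `upper` field declared above (false here)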
15 changes: 15 additions & 0 deletions python/tvm/relay/frontend/onnx.py
@@ -4685,6 +4685,20 @@ def _impl_v12(cls, inputs, attr, params):
        return _op.einsum(inputs, equation)


class Trilu(OnnxOpConverter):
    """Operator converter for Trilu"""

    @classmethod
    def _impl_v14(cls, inputs, attr, params):
        upper = attr.get("upper", True)
        if len(inputs) == 2:
            data, k = inputs
        else:
            data = inputs[0]
            k = 0
        return _op.trilu(data, k, upper)


class RandomNormal(OnnxOpConverter):
    """Operator converter for random_normal"""

@@ -5345,6 +5359,7 @@ def _get_convert_map(opset):
"CumSum": CumSum.get_converter(opset),
"Unique": Unique.get_converter(opset),
"Einsum": Einsum.get_converter(opset),
"Trilu": Trilu.get_converter(opset),
# defs/control_flow
"Loop": Loop.get_converter(opset),
"If": If.get_converter(opset),
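An end-to-end sketch of the new converter (illustrative, not from this commit; the graph name and shapes are arbitrary): build a one-node ONNX model at opset 14 and import it. With a single input, the converter falls back to k = 0.

import numpy as np
import onnx
from onnx import TensorProto, helper
import tvm
from tvm import relay

node = helper.make_node("Trilu", inputs=["x"], outputs=["y"], upper=0)
graph = helper.make_graph(
    [node],
    "trilu_example",
    inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [3, 3])],
    outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 3])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 14)])

mod, params = relay.frontend.from_onnx(model, shape={"x": (3, 3)})
x = np.arange(9, dtype="float32").reshape(3, 3)
res = relay.create_executor("graph", mod=mod).evaluate()(x).numpy()
np.testing.assert_allclose(res, np.tril(x))  # upper=0 keeps the lower triangle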
35 changes: 8 additions & 27 deletions python/tvm/relay/frontend/pytorch.py
@@ -318,31 +318,6 @@ def square(self, inputs, input_types):
        (dtype,) = input_types
        return _op.power(inputs[0], _expr.const(2, dtype))

    def tril(self, inputs, input_types):
        data = inputs[0]
        if len(inputs) == 2:
            k_value = inputs[1]
        else:
            k_value = 0
        input_shape = self.infer_shape(data)
        k1, k2 = input_shape[-2:]
        k1 = k_value + 1
        diag_input = _op.zeros(input_shape, dtype=input_types[0])
        return _op.matrix_set_diag(data, diag_input, k=(k1, k2))

    def triu(self, inputs, input_types):
        data = inputs[0]
        if len(inputs) == 2:
            k_value = inputs[1]
        else:
            k_value = 0
        input_shape = self.infer_shape(data)
        k1, k2 = input_shape[-2:]
        k1 = (k1 * -1) - 1
        k2 = k_value - 1
        diag_input = _op.zeros(input_shape, dtype=input_types[0])
        return _op.matrix_set_diag(data, diag_input, k=(k1, k2))

    def lerp(self, inputs, input_types):
        if len(inputs) != 3:
            msg = "Wrong number of arguments (%d) to parse." % (len(inputs))
@@ -3405,6 +3380,12 @@ def grid_sampler(self, inputs, input_types):
            inputs[0], grid, interpolate_str, layout, padding_mode_str, align_corners
        )

    def trilu(self, inputs, input_types, mode):
        data = inputs[0]
        k = inputs[1] if inputs[1] else 0
        upper = True if mode == "triu" else False
        return _op.trilu(data, k, upper)

    # Operator mappings
    def create_convert_map(self):
        self.convert_map = {
@@ -3567,8 +3548,8 @@ def create_convert_map(self):
"aten::sqrt": self.make_unary("sqrt"),
"aten::rsqrt": self.make_unary("rsqrt"),
"aten::square": self.square,
"aten::tril": self.tril,
"aten::triu": self.triu,
"aten::tril": functools.partial(self.trilu, mode="tril"),
"aten::triu": functools.partial(self.trilu, mode="triu"),
"aten::ceil": self.make_unary("ceil"),
"aten::floor": self.make_unary("floor"),
"aten::round": self.make_unary("round"),
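A quick sketch of the new mapping in use (illustrative, not from this commit): trace torch.tril and import the traced graph. aten::tril now routes through trilu(..., mode="tril") instead of the old matrix_set_diag path.

import numpy as np
import torch
import tvm
from tvm import relay

def f(x):
    return torch.tril(x, diagonal=1)

inp = torch.rand(4, 4)
traced = torch.jit.trace(f, inp)
mod, params = relay.frontend.from_pytorch(traced, [("x", (4, 4))])
res = relay.create_executor("graph", mod=mod).evaluate()(inp.numpy()).numpy()
np.testing.assert_allclose(res, np.tril(inp.numpy(), 1))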
4 changes: 4 additions & 0 deletions python/tvm/relay/op/_transform.py
@@ -191,6 +191,10 @@ def stft_shape_func(attrs, inputs, _):
    ]


# trilu
_reg.register_strategy("trilu", strategy.trilu_strategy)


# scatter_add
@_reg.register_compute("scatter_add")
def compute_scatter_add(attrs, inputs, output_type):
5 changes: 5 additions & 0 deletions python/tvm/relay/op/op_attrs.py
@@ -617,3 +617,8 @@ class NLLLossAttrs(Attrs):
@tvm._ffi.register_object("relay.attrs.FixedPointMultiplyAttrs")
class FixedPointMultiplyAttrs(Attrs):
    """Attributes used in fixed_point_multiply operators"""


@tvm._ffi.register_object("relay.attrs.TriluAttrs")
class TriluAttrs(Attrs):
    """Attributes used in trilu operators"""
28 changes: 28 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
@@ -1460,6 +1460,34 @@ def _compute_stft(attrs, inputs, output_type):
    return _compute_stft


# trilu
@override_native_generic_func("trilu_strategy")
def trilu_strategy(attrs, outs, out_type, target):
    """trilu generic strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_trilu(topi.trilu),
        wrap_topi_schedule(topi.generic.schedule_extern),
        name="trilu.generic",
    )
    return strategy


def wrap_compute_trilu(topi_compute):
    """Wrap trilu compute"""

    def _compute_trilu(attrs, inputs, output_type):
        return [
            topi_compute(
                inputs[0],
                inputs[1],
                attrs.upper,
            )
        ]

    return _compute_trilu


# roi_pool
@generic_func
def schedule_roi_pool(attrs, outs, target):
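Because the strategy pairs the TE compute with the generic extern schedule, any build that contains a trilu op exercises the path above. A minimal sketch (illustrative, not from this commit):

import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_executor

x = relay.var("x", shape=(3, 3), dtype="float32")
mod = tvm.ir.IRModule.from_expr(relay.trilu(x, k=0, upper=False))
lib = relay.build(mod, target="llvm")  # lowers through trilu.generic

gmod = graph_executor.GraphModule(lib["default"](tvm.cpu()))
gmod.set_input("x", np.ones((3, 3), dtype="float32"))
gmod.run()
print(gmod.get_output(0))  # lower triangle of ones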
43 changes: 43 additions & 0 deletions python/tvm/relay/op/transform.py
@@ -1889,3 +1889,46 @@ def stft(
        window = _make.ones([n_fft], "int32")

    return _make.stft(data, n_fft, hop_length, win_length, window, normalized, onesided)


def trilu(data, k, upper=True):
    """
    Given a 2-D matrix or batches of 2-D matrices, returns the
    upper or lower triangular part of the tensor.

    Parameters
    ----------
    data: relay.Expr
        The tensor that trilu will be applied to. Must be either
        a 2D matrix or a tensor of batches of 2D matrices.

    k: int
        The number of diagonals above or below the main diagonal
        to exclude or include.

    upper: bool, optional
        If True, only upper triangular values of input are kept,
        if False, the lower triangular values are kept.

    Returns
    -------
    ret : relay.Expr
        The new tensor with appropriate diagonals set to zero.

    Examples
    --------
    .. code-block:: python

        x = [[0, 1, 2],
             [3, 4, 5],
             [6, 7, 8]]

        relay.trilu(x, k=0, upper=True) =
            [[0, 1, 2],
             [0, 4, 5],
             [0, 0, 8]]
    """
    if not isinstance(k, Expr):
        k = const(k, dtype="int32")
    return _make.trilu(data, k, upper)
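A runnable version of the docstring example, as a sketch (assuming an LLVM-enabled build; not part of the diff):

import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(3, 3), dtype="float32")
mod = tvm.ir.IRModule.from_expr(relay.trilu(x, k=0, upper=True))
data = np.arange(9, dtype="float32").reshape(3, 3)
res = relay.create_executor("graph", mod=mod).evaluate()(data).numpy()
np.testing.assert_allclose(res, np.triu(data))  # zeros below the main diagonal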
58 changes: 58 additions & 0 deletions python/tvm/topi/transform.py
@@ -1001,3 +1001,61 @@ def sliding_window(data, axis, window_shape, strides):
        The resulting tensor.
    """
    return cpp.sliding_window(data, axis, window_shape, strides)


def trilu(data, k, upper):
    """
    Given a 2-D matrix or batches of 2-D matrices, returns the
    upper or lower triangular part of the tensor.

    Parameters
    ----------
    data: tvm.te.Tensor
        The tensor that trilu will be applied to. Must be either
        a 2D matrix or a tensor of batches of 2D matrices.

    k: tvm.te.Tensor
        The number of diagonals above or below the main diagonal
        to exclude or include.

    upper: bool
        If True, only upper triangular values of input are kept,
        if False, the lower triangular values are kept.

    Returns
    -------
    ret : tvm.te.Tensor
        The new tensor with appropriate diagonals set to zero.

    Examples
    --------
    .. code-block:: python

        x = [[0, 1, 2],
             [3, 4, 5],
             [6, 7, 8]]

        topi.trilu(x, k=0, upper=True) =
            [[0, 1, 2],
             [0, 4, 5],
             [0, 0, 8]]
    """
    # Make sure datatype is consistent.
    if k.dtype != "int32":
        k = tvm.tir.Cast("int32", k)

    # Check either above or below diagonal depending on upper.
    check_op = tvm.tir.GE
    if upper:
        check_op = tvm.tir.LE

    def _apply_trilu(*indices):
        row_index = indices[-2]
        col_index = indices[-1]
        other_indices = indices[:-2]
        check_position = check_op(row_index, col_index - k)
        value = data(*other_indices, row_index, col_index)
        return tvm.tir.Select(check_position, value, tvm.tir.const(0, data.dtype))

    return te.compute(data.shape, _apply_trilu, name="trilu")
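The compute keeps element (i, j) when i <= j - k for upper or i >= j - k for lower, and selects zero otherwise. For intuition, the same predicate written in NumPy with broadcasting (a reference sketch, not part of the commit):

import numpy as np

def trilu_ref(data, k, upper):
    # Same per-element rule as _apply_trilu, vectorized over the last two axes.
    rows = np.arange(data.shape[-2]).reshape(-1, 1)
    cols = np.arange(data.shape[-1]).reshape(1, -1)
    keep = (rows <= cols - k) if upper else (rows >= cols - k)
    return np.where(keep, data, 0)

x = np.arange(9).reshape(3, 3)
assert (trilu_ref(x, 0, upper=False) == np.tril(x)).all()
assert (trilu_ref(x, 1, upper=True) == np.triu(x, 1)).all()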
50 changes: 50 additions & 0 deletions src/relay/op/tensor/transform.cc
@@ -4230,5 +4230,55 @@ RELAY_REGISTER_OP("invert_permutation")
    .set_attr<TOpPattern>("TOpPattern", kInjective)
    .set_attr<TOpIsStateful>("TOpIsStateful", false);

// Trilu

TVM_REGISTER_NODE_TYPE(TriluAttrs);

bool TriluRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
              const TypeReporter& reporter) {
  // types: [data, k, result]
  ICHECK_EQ(types.size(), 3) << "Trilu: expect 3 types but " << types.size() << " provided";
  ICHECK_EQ(num_inputs, 2) << "Trilu: expect 2 inputs but " << num_inputs << " provided";
  auto data = types[0].as<TensorTypeNode>();
  if (data == nullptr) {
    ICHECK(types[0].as<IncompleteTypeNode>())
        << "Trilu: expect input type to be TensorType but get " << types[0];
    return false;
  }

  auto k = types[1].as<TensorTypeNode>();
  if (k == nullptr) {
    ICHECK(types[1].as<IncompleteTypeNode>())
        << "Trilu: expect k type to be TensorType but get " << types[1];
    return false;
  }

  ICHECK(k->shape.size() == 0) << "Trilu: k must be a 0-D tensor but get " << k;

  // Output shape is the same as input shape.
  reporter->Assign(types[2], TensorType(data->shape, data->dtype));
  return true;
}

Expr MakeTrilu(Expr data, Expr k, bool upper) {
  auto attrs = make_object<TriluAttrs>();
  attrs->upper = upper;
  static const Op& op = Op::Get("trilu");
  return Call(op, {data, k}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.trilu").set_body_typed(MakeTrilu);

RELAY_REGISTER_OP("trilu")
    .describe(
        R"code(Filters out the upper or lower portion of an input tensor on one side of a diagonal.
)code" TVM_ADD_FILELINE)
    .set_num_inputs(2)
    .add_argument("data", "Tensor", "The input tensor")
    .add_argument("k", "Tensor", "The number of diagonals above or below the main to exclude.")
    .add_type_rel("trilu", TriluRel)
    .set_support_level(3)
    .set_attr<TOpPattern>("TOpPattern", kElemWise);

} // namespace relay
} // namespace tvm
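The type relation pins k to a 0-D tensor and copies the input's shape and dtype through to the result. A small sketch of the inferred type (illustrative, not from this commit):

import tvm
from tvm import relay

x = relay.var("x", shape=(2, 3, 3), dtype="float32")
f = relay.Function([x], relay.trilu(x, k=1, upper=True))
mod = relay.transform.InferType()(tvm.ir.IRModule.from_expr(f))
print(mod["main"].ret_type)  # Tensor[(2, 3, 3), float32], same as the input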
16 changes: 0 additions & 16 deletions tests/python/frontend/onnx/test_forward.py
@@ -5242,23 +5242,7 @@ def verify_eyelike(indata, dynamic=False):
"test_training_dropout_mask",
"test_training_dropout_zero_ratio",
"test_training_dropout_zero_ratio_mask",
"test_tril",
"test_tril_pos",
"test_tril_square",
"test_tril_square_neg",
"test_tril_neg",
"test_tril_one_row_neg",
"test_tril_out_neg",
"test_tril_out_pos",
"test_tril_zero",
"test_triu",
"test_triu_one_row",
"test_triu_out_neg_out",
"test_triu_out_pos",
"test_triu_neg",
"test_triu_pos",
"test_triu_square",
"test_triu_square_neg",
"test_triu_zero",
"test_unique_sorted_with_axis",
"test_unique_sorted_with_axis_3d",
10 changes: 10 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
@@ -4616,5 +4616,15 @@ def test_fn(x, y, w):
    verify_model(test_fn, [x, y, w[0]])


def test_trilu():
    def _test_trilu(op, diagonal):
        return lambda inp: op(inp, diagonal)

    for op in [torch.triu, torch.tril]:
        verify_model(_test_trilu(op, 0), [torch.rand(size=[3, 3])])
        verify_model(_test_trilu(op, 1), [torch.rand(size=[6, 6])])
        verify_model(_test_trilu(op, -2), [torch.rand(size=[6, 6])])


if __name__ == "__main__":
    pytest.main([__file__])
29 changes: 29 additions & 0 deletions tests/python/relay/test_op_level3.py
@@ -2207,5 +2207,34 @@ def test_stft(
    )


def test_trilu(target="llvm", dev=tvm.cpu()):
    def verify_trilu(data_shape, upper=True, k=0):
        data = relay.var("data", relay.TensorType(data_shape, "float32"))
        y = relay.trilu(data, k, upper)
        mod = tvm.ir.IRModule.from_expr(y)

        data_np = np.random.normal(size=data_shape).astype("float32")
        tvm_res = (
            relay.create_executor("graph", mod=mod, device=dev, target=target)
            .evaluate()(data_np)
            .numpy()
        )
        if upper:
            np_res = np.triu(data_np, k)
        else:
            np_res = np.tril(data_np, k)
        tvm.testing.assert_allclose(tvm_res, np_res)

    # Test upper and lower triangle
    verify_trilu((3, 3), True, 0)
    verify_trilu((3, 3), False, 0)
    # Test larger matrices with offset.
    verify_trilu((6, 6), True, 1)
    verify_trilu((6, 6), False, 2)
    verify_trilu((6, 6), False, -2)
    # Test batch size
    verify_trilu((8, 6, 6), False, -2)


if __name__ == "__main__":
    tvm.testing.main()