[ONNX] Add Einsum converter #8985

Merged (7 commits) on Sep 15, 2021
11 changes: 10 additions & 1 deletion include/tvm/relay/attrs/transform.h
@@ -475,7 +475,7 @@ struct ScanopAttrs : public tvm::AttrsNode<ScanopAttrs> {
        .describe("The first element is not included")
        .set_default(Bool(false));
  }
-};
+};  // struct ScanopAttrs

/*! \brief Attributes used in unique operator */
struct UniqueAttrs : public tvm::AttrsNode<UniqueAttrs> {
@@ -489,6 +489,15 @@
  }
};  // struct UniqueAttrs

/*! \brief Attributes used in einsum operator */
struct EinsumAttrs : public tvm::AttrsNode<EinsumAttrs> {
  String equation;

  TVM_DECLARE_ATTRS(EinsumAttrs, "relay.attrs.EinsumAttrs") {
    TVM_ATTR_FIELD(equation).describe("The einsum expression string");
  }
};  // struct EinsumAttrs

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
10 changes: 10 additions & 0 deletions python/tvm/relay/frontend/onnx.py
@@ -3517,6 +3517,15 @@ def _impl_v11(cls, inputs, attr, params):
        return _expr.TupleWrapper(_expr.Tuple([unique_vals, indices, inverse_indices, counts]), 4)


class Einsum(OnnxOpConverter):
"""Operator converter for Einsum"""

@classmethod
def _impl_v12(cls, inputs, attr, params):
equation = attr["equation"].decode("utf-8")
return _op.einsum(inputs, equation)


class RandomUniform(OnnxOpConverter):
"""Operator converter for random_uniform"""

@@ -3763,6 +3772,7 @@ def _get_convert_map(opset):
"Range": Range.get_converter(opset),
"CumSum": CumSum.get_converter(opset),
"Unique": Unique.get_converter(opset),
"Einsum": Einsum.get_converter(opset),
# defs/control_flow
"Loop": Loop.get_converter(opset),
"If": If.get_converter(opset),
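For context, a minimal sketch (not part of the PR) of an ONNX graph that exercises this converter; the Einsum op was introduced in ONNX opset 12, which is why the converter implements _impl_v12:

    import onnx
    from onnx import TensorProto, helper

    from tvm import relay

    # One-node graph: c = einsum("ij,jk->ik", a, b)
    node = helper.make_node("Einsum", inputs=["a", "b"], outputs=["c"], equation="ij,jk->ik")
    graph = helper.make_graph(
        [node],
        "einsum_test",
        inputs=[
            helper.make_tensor_value_info("a", TensorProto.FLOAT, [2, 3]),
            helper.make_tensor_value_info("b", TensorProto.FLOAT, [3, 4]),
        ],
        outputs=[helper.make_tensor_value_info("c", TensorProto.FLOAT, [2, 4])],
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 12)])

    # from_onnx dispatches "Einsum" to Einsum._impl_v12 above, which decodes the
    # equation attribute and emits relay.einsum.
    mod, params = relay.frontend.from_onnx(model, shape={"a": (2, 3), "b": (3, 4)})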
5 changes: 5 additions & 0 deletions python/tvm/relay/op/_transform.py
@@ -182,6 +182,11 @@ def compute_unique(attrs, inputs, output_type):
_reg.register_strategy("invert_permutation", strategy.invert_permutation_strategy)
_reg.register_shape_func("invert_permutation", False, elemwise_shape_func)


# einsum
_reg.register_strategy("einsum", strategy.einsum_strategy)


#####################
# Shape functions #
#####################
12 changes: 12 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
@@ -1210,3 +1210,15 @@ def invert_permutation_strategy_cuda(attrs, inputs, out_type, target):
name="invert_permutation.cuda",
)
return strategy


@einsum_strategy.register(["cuda", "gpu"])
def einsum_strategy_cuda(attrs, inputs, out_type, target):
    """einsum cuda strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_einsum(topi.cuda.einsum),
        wrap_topi_schedule(topi.cuda.schedule_einsum),
        name="einsum.cuda",
    )
    return strategy

Review thread on the strategy.add_implementation call:

masahi (Member), Sep 15, 2021:
This would lead to an error if you try to run it on CUDA, right? It is better to remove this strategy until we have a CUDA schedule ready. The error message would be clearer than the one from incorrect scheduling.

anwang2009 (Contributor, Author):
The test_onnx_nodes tests run on CUDA and have passed. I will remove it if you think removal is still necessary despite that.

masahi (Member):
Oh interesting. I didn't know that topi.generic.schedule_extern somehow generates a valid schedule.
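To check the concern raised above, a minimal sketch (not part of the PR, assuming a CUDA-enabled TVM build) that compiles and runs relay.einsum through this strategy:

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.contrib import graph_executor

    a = relay.var("a", shape=(2, 3), dtype="float32")
    b = relay.var("b", shape=(3, 4), dtype="float32")
    func = relay.Function([a, b], relay.einsum([a, b], "ij,jk->ik"))

    # Compiling for CUDA exercises einsum_strategy_cuda; an invalid schedule
    # would surface here or at run time.
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(tvm.IRModule.from_expr(func), target="cuda")

    dev = tvm.cuda(0)
    rt = graph_executor.GraphModule(lib["default"](dev))
    rt.set_input("a", np.random.rand(2, 3).astype("float32"))
    rt.set_input("b", np.random.rand(3, 4).astype("float32"))
    rt.run()
    assert rt.get_output(0).numpy().shape == (2, 4)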
21 changes: 21 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
@@ -1669,3 +1669,24 @@ def invert_permutation_strategy(attrs, inputs, out_type, target):
name="invert_permutation.generic",
)
return strategy


def wrap_compute_einsum(topi_compute):
"""Wrap einsum topi compute"""

def _compute_einsum(attrs, inputs, _):
return [topi_compute(attrs.equation, *inputs)]

return _compute_einsum


@override_native_generic_func("einsum_strategy")
def einsum_strategy(attrs, inputs, out_type, target):
"""einsum generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_einsum(topi.einsum),
wrap_topi_schedule(topi.generic.schedule_einsum),
name="einsum.generic",
)
return strategy
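The generic strategy simply pairs the topi.einsum compute with the default schedule. A rough sketch (not from the PR) exercising the same pieces directly at the TE level on an llvm target:

    import tvm
    from tvm import te, topi

    A = te.placeholder((2, 3), name="A", dtype="float32")
    B = te.placeholder((3, 4), name="B", dtype="float32")
    # Equivalent to what wrap_compute_einsum does with attrs.equation and inputs:
    C = topi.einsum("ij,jk->ik", A, B)  # a matrix multiply written as einsum

    # schedule_einsum falls back to the default schedule, which needs a
    # target context to be in scope.
    with tvm.target.Target("llvm"):
        s = topi.generic.schedule_einsum([C])
    f = tvm.build(s, [A, B, C], target="llvm")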
23 changes: 23 additions & 0 deletions python/tvm/relay/op/tensor.py
@@ -1104,6 +1104,29 @@ def concatenate(data, axis):
    return _make.concatenate(Tuple(data), axis)


def einsum(data, equation):
    """Evaluates the Einstein summation convention on data.

    Parameters
    ----------
    data : Union(List[relay.Expr], Tuple[relay.Expr])
        A list of tensors.
    equation : str
        The einsum expression string.

    Returns
    -------
    result : relay.Expr
        The output tensor from the einsum op.
    """
    data = list(data)
    if not data:
        raise ValueError("relay.einsum requires data to be non-empty.")
    if not isinstance(equation, str):
        raise ValueError("einsum `equation` must be a str")
    return _make.einsum(Tuple(data), equation)
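A small usage sketch of the new Python API (illustrative, not part of the PR):

    from tvm import relay

    x = relay.var("x", shape=(2, 3))
    y = relay.var("y", shape=(3, 4))
    # "ij,jk->ik" contracts over j: an ordinary matrix multiply.
    z = relay.einsum([x, y], "ij,jk->ik")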


def stack(data, axis):
"""Join a sequence of arrays along a new axis.
16 changes: 16 additions & 0 deletions python/tvm/topi/generic/search.py
@@ -86,3 +86,19 @@ def schedule_unique(outs):
        The computation schedule for the op.
    """
    return _default_schedule(outs, False)


def schedule_einsum(outs):
"""Schedule for einsum operator.

Parameters
----------
outs: Array of Tensor
The computation graph description of einsum.

Returns
-------
s: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
80 changes: 80 additions & 0 deletions src/relay/op/tensor/transform.cc
@@ -33,6 +33,7 @@
#include <tvm/tir/op.h>
#include <tvm/topi/broadcast.h>
#include <tvm/topi/detail/constant_utils.h>
#include <tvm/topi/einsum.h>
#include <tvm/topi/elemwise.h>
#include <tvm/topi/nn.h>
#include <tvm/topi/reduction.h>
@@ -2431,6 +2432,85 @@ RELAY_REGISTER_OP("broadcast_to_like")
    .set_attr<FTVMCompute>("FTVMCompute", BroadCastToLikeCompute)
    .set_attr<TOpPattern>("TOpPattern", kBroadcast);

// relay.einsum
TVM_REGISTER_NODE_TYPE(EinsumAttrs);

bool EinsumRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
               const TypeReporter& reporter) {
  // Check attrs
  const EinsumAttrs* param = attrs.as<EinsumAttrs>();
  if (param == nullptr) {
    reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
                                     << "the call attributes are not defined");
    return false;
  }

  // types: [data, result]
  ICHECK_EQ(types.size(), 2) << "the arity of einsum is 2, not " << types.size();

  // Check input type is a tuple.
  const auto* tensor_tuple = types[0].as<TupleTypeNode>();
  if (tensor_tuple == nullptr) {
    reporter->GetDiagCtx().EmitFatal(
        Diagnostic::Error(reporter->GetSpan())
        << "einsum requires a tuple of tensors as the first argument, found "
        << PrettyPrint(types[0]));
    return false;
  }

  // Check the input tuple consists of tensors with consistent dtype.
  const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]);
  const DataType dtype = first->dtype;
  std::vector<Array<PrimExpr>> input_shapes;
  for (const Type& ele : tensor_tuple->fields) {
    if (ele.as<IncompleteTypeNode>()) {
      return false;
    }

    const auto& e = Downcast<TensorType>(ele);

    const DataType& e_dtype = e->dtype;
    if (e_dtype != dtype) {
      throw Error("relay.einsum requires all tensors have the same dtype");
    }
    input_shapes.push_back(e->shape);
  }

  // Calculate output shape
  Array<IndexExpr> oshape = topi::NumpyEinsumShape(param->equation, input_shapes);

  auto rtype = TensorType(oshape, dtype);
  reporter->Assign(types[1], rtype);
  return true;
}

Array<te::Tensor> EinsumCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                const Type& out_type) {
  const EinsumAttrs* param = attrs.as<EinsumAttrs>();
  ICHECK(param != nullptr);
  return Array<te::Tensor>{topi::einsum(param->equation, inputs)};
}

Expr MakeEinsum(Expr data, String equation) {
  auto attrs = make_object<EinsumAttrs>();
  attrs->equation = std::move(equation);
  static const Op& op = Op::Get("einsum");
  return Call(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.einsum").set_body_typed(MakeEinsum);

RELAY_REGISTER_OP("einsum")
.describe(R"doc(Evaluates the Einstein summation convention
on the operands)doc" TVM_ADD_FILELINE)
.set_attrs_type<EinsumAttrs>()
.set_num_inputs(1)
.add_argument("data", "Tuple of Tensors", "The input list of tensors.")
.set_support_level(11)
.add_type_rel("Einsum", EinsumRel)
.set_attr<FTVMCompute>("FTVMCompute", EinsumCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

// Adapter function to make int array.
Array<Integer> GetIntArray(Array<IndexExpr> arr) {
  for (size_t i = 0; i < arr.size(); ++i) {
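EinsumRel computes the result shape from the equation via topi::NumpyEinsumShape; a quick sketch (not from the PR) showing the relation resolve shapes during type inference from Python:

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(5, 5))
    diag = relay.einsum([x], "ii->i")  # main diagonal, so the result shape is (5,)
    mod = relay.transform.InferType()(tvm.IRModule.from_expr(relay.Function([x], diag)))
    print(mod["main"].ret_type)  # Tensor[(5,), float32]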
5 changes: 0 additions & 5 deletions tests/python/frontend/onnx/test_forward.py
@@ -4739,11 +4739,6 @@ def verify_eyelike(indata):
"test_dropout_default_mask",
"test_dropout_default_mask_ratio",
"test_dropout_default_ratio",
"test_einsum_batch_diagonal",
"test_einsum_batch_matmul",
"test_einsum_inner_prod",
"test_einsum_sum",
"test_einsum_transpose",
"test_greater_equal",
"test_greater_equal_bcast",
"test_if_seq",