diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h
index 9456ea80d860..a18d42902503 100644
--- a/include/tvm/ir/op.h
+++ b/include/tvm/ir/op.h
@@ -244,12 +244,18 @@ class OpRegEntry {
       runtime::TypedPackedFunc<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>
           type_rel_func);
   /*!
-   * \brief Set the the attrs type key and index to be AttrsType.
+   * \brief Set the attrs type key and index to be AttrsType.
    * \tparam AttrsType the attribute type to be set.
    * \return reference to self.
    */
   template <typename AttrsType>
   inline OpRegEntry& set_attrs_type();
+  /*!
+   * \brief Set the attrs type key and index to be AttrsType.
+   * \param key The attribute type key to be set.
+   * \return reference to self.
+   */
+  inline OpRegEntry& set_attrs_type_key(const String& key);
   /*!
    * \brief Set the num_inputs
    * \param n The number of inputs to be set.
@@ -454,6 +460,12 @@ inline OpRegEntry& OpRegEntry::set_attrs_type() {  // NOLINT(*)
   return *this;
 }
 
+inline OpRegEntry& OpRegEntry::set_attrs_type_key(const String& key) {  // NOLINT(*)
+  get()->attrs_type_key = key;
+  get()->attrs_type_index = Object::TypeKey2Index(key);
+  return *this;
+}
+
 inline OpRegEntry& OpRegEntry::set_support_level(int32_t n) {  // NOLINT(*)
   get()->support_level = n;
   return *this;
diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py
index 70c5988d6316..b4cc4421b169 100644
--- a/python/tvm/ir/__init__.py
+++ b/python/tvm/ir/__init__.py
@@ -23,7 +23,7 @@
 from .tensor_type import TensorType
 from .type_relation import TypeCall, TypeRelation
 from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, Range
-from .op import Op, register_op, register_op_attr, register_intrin_lowering
+from .op import Op, register_op_attr, register_intrin_lowering
 from .function import CallingConv, BaseFunc
 from .adt import Constructor, TypeData
 from .module import IRModule
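Since `register_op` is dropped from the `tvm.ir` exports above, downstream callers need the new registration entry point in `tvm.relay.op.op` (added later in this patch). A minimal migration sketch, assuming the patch is applied (`my_custom_op` is only illustrative):

```python
from tvm.relay.op import op as _op

# Previously: tvm.ir.register_op("my_custom_op")
# With this patch, register through the relay op namespace instead.
_op.register("my_custom_op", "An illustrative placeholder operator.")
assert _op.get("my_custom_op").name == "my_custom_op"
```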
+ """ + _ffi_api.OpAddArgument(self, name, type, description) + + def set_support_level(self, level): + """Set the support level of op. + + Parameters + ---------- + level : int + The support level. + """ + _ffi_api.OpSetSupportLevel(self, level) + + def set_num_inputs(self, n): + """Set the support level of op. + + Parameters + ---------- + n : int + The input number. + """ + _ffi_api.OpSetNumInputs(self, n) + + def set_attrs_type_key(self, key): + """Set the attribute type key of op. + + Parameters + ---------- + key : str + The type key. + """ + _ffi_api.OpSetAttrsTypeKey(self, key) def register_op_attr(op_name, attr_key, value=None, level=10): diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index 5882027fb1d8..33cb46d67f34 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -21,6 +21,7 @@ from tvm.driver import lower, build from tvm.target import get_native_generic_func, GenericFunc from tvm.runtime import Object +import tvm.ir._ffi_api from . import _make @@ -40,6 +41,40 @@ def get(op_name): return tvm.ir.Op.get(op_name) +def register(op_name, describe=""): + """Get the Op for a given name. + when the op_name is not registered, create a new empty op with the given name. + when the op_name has been registered, abort with an error message. + + Parameters + ---------- + op_name : str + The operator name + + describe : Optional[str] + The operator description + """ + + tvm.ir._ffi_api.RegisterOp(op_name, describe) + + +def register_stateful(op_name, stateful, level=10): + """Register operator pattern for an op. + + Parameters + ---------- + op_name : str + The name of the op. + + stateful : bool + The stateful flag. + + level : int + The priority level + """ + tvm.ir.register_op_attr(op_name, "TOpIsStateful", stateful, level) + + class OpPattern(object): """Operator generic patterns diff --git a/src/ir/op.cc b/src/ir/op.cc index 8fd34d30ffa7..861545e6b959 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -102,10 +102,71 @@ TVM_REGISTER_GLOBAL("ir.OpResetAttr").set_body_typed([](Op op, String attr_name) reg.reset_attr(attr_name); }); -TVM_REGISTER_GLOBAL("ir.RegisterOp").set_body_typed([](String op_name) { +TVM_REGISTER_GLOBAL("ir.RegisterOp").set_body_typed([](String op_name, String descr) { const OpRegEntry* reg = OpRegistry::Global()->Get(op_name); ICHECK(reg == nullptr) << "AttributeError: Operator " << op_name << " is registered before"; - OpRegistry::Global()->RegisterOrGet(op_name).set_name(); + auto& op = OpRegistry::Global()->RegisterOrGet(op_name).set_name(); + op.describe(descr); +}); + +// This is exposed FFI api for prototyping using in python. +// Note: it is not full of the C++ type relation, +// since in python side we don't have access to the type reporter, +// and cannot propagate constraints to the inputs, only to the output. +TVM_REGISTER_GLOBAL("ir.OpAddTypeRel") + .set_body_typed([](Op op, String rel_name, runtime::TVMArgValue value) { + auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); + if (value.type_code() == kTVMPackedFuncHandle) { + // do an eager copy of the PackedFunc to avoid deleting function from frontend. 
diff --git a/src/ir/op.cc b/src/ir/op.cc
index 8fd34d30ffa7..861545e6b959 100644
--- a/src/ir/op.cc
+++ b/src/ir/op.cc
@@ -102,10 +102,71 @@ TVM_REGISTER_GLOBAL("ir.OpResetAttr").set_body_typed([](Op op, String attr_name)
   reg.reset_attr(attr_name);
 });
 
-TVM_REGISTER_GLOBAL("ir.RegisterOp").set_body_typed([](String op_name) {
+TVM_REGISTER_GLOBAL("ir.RegisterOp").set_body_typed([](String op_name, String descr) {
   const OpRegEntry* reg = OpRegistry::Global()->Get(op_name);
   ICHECK(reg == nullptr) << "AttributeError: Operator " << op_name << " is registered before";
-  OpRegistry::Global()->RegisterOrGet(op_name).set_name();
+  auto& op = OpRegistry::Global()->RegisterOrGet(op_name).set_name();
+  op.describe(descr);
+});
+
+// This is an FFI API exposed for prototyping from Python.
+// Note: it does not expose the full C++ type relation,
+// since on the Python side we don't have access to the type reporter
+// and cannot propagate constraints to the inputs, only to the output.
+TVM_REGISTER_GLOBAL("ir.OpAddTypeRel")
+    .set_body_typed([](Op op, String rel_name, runtime::TVMArgValue value) {
+      auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name();
+      if (value.type_code() == kTVMPackedFuncHandle) {
+        // do an eager copy of the PackedFunc to keep the frontend from deleting the function
+        PackedFunc* fcopy = new PackedFunc(value.operator tvm::runtime::PackedFunc());
+        auto f = [=](const Array<Type>& args, int num_inputs, const Attrs& attrs,
+                     const TypeReporter& reporter) -> bool {
+          Array<Type> input_types(args.begin(), args.end() - 1);
+          // call the customized relation function
+          // *fcopy's signature: function (args: List[Type], attrs: Attrs) -> Type
+          Type ret_type = (*fcopy)(input_types, attrs);
+          // when ret_type is defined, inference of the output type succeeded, so do the
+          // type assignment; otherwise an inference failure happens
+          if (ret_type.defined()) {
+            // the last argument is the output
+            // TODO(xqdan): support multiple outputs
+            reporter->Assign(args.back(), ret_type);
+            return true;
+          }
+          return false;
+        };
+        // adjust the function call to the calling conventions of the Relay type system
+        // with a TypeReporter
+        auto type_rel = runtime::TypedPackedFunc<bool(const Array<Type>&, int, const Attrs&,
+                                                      const TypeReporter&)>(f);
+        reg.add_type_rel(rel_name, type_rel);
+      } else if (value.type_code() == kTVMNullptr) {
+        // Call the predefined relation functions of Relay
+        auto func_name = std::string("tvm.relay.type_relation.") + rel_name;
+        auto* f = runtime::Registry::Get(func_name);
+        ICHECK(f != nullptr) << "AddTypeRel error: no type_relation registered.";
+        reg.add_type_rel(rel_name, *f);
+      }
+    });
+
+TVM_REGISTER_GLOBAL("ir.OpAddArgument")
+    .set_body_typed([](Op op, String name, String type, String description) {
+      auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name();
+      reg.add_argument(name, type, description);
+    });
+
+TVM_REGISTER_GLOBAL("ir.OpSetSupportLevel").set_body_typed([](Op op, int level) {
+  auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name();
+  reg.set_support_level(level);
+});
+
+TVM_REGISTER_GLOBAL("ir.OpSetNumInputs").set_body_typed([](Op op, int n) {
+  auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name();
+  reg.set_num_inputs(n);
+});
+
+TVM_REGISTER_GLOBAL("ir.OpSetAttrsTypeKey").set_body_typed([](Op op, String key) {
+  auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name();
+  reg.set_attrs_type_key(key);
 });
 
 TVM_REGISTER_GLOBAL("ir.RegisterOpAttr")
diff --git a/tests/python/relay/test_ir_op.py b/tests/python/relay/test_ir_op.py
index fe559697348b..edb8086dd426 100644
--- a/tests/python/relay/test_ir_op.py
+++ b/tests/python/relay/test_ir_op.py
@@ -17,6 +17,7 @@
 import tvm
 from tvm import relay
 from tvm.relay.testing.temp_op_attr import TempOpAttr
+from tvm.relay.op import op as _op
 
 
 def test_op_attr():
@@ -103,11 +104,20 @@ def test_op_register():
-    """Tests register_op functionality."""
+    """Tests operator registration functionality."""
     op_name = "custom_op"
 
-    tvm.ir.register_op(op_name)
-    tvm.ir.register_op_attr(op_name, "num_inputs", 2, 256)
-
-    assert tvm.ir.Op.get(op_name).name == op_name
-    assert tvm.ir.Op.get(op_name).num_inputs == 2
+    _op.register(op_name, r"code(Add two tensor with inner broadcasting.)code")
+    _op.get(op_name).set_num_inputs(2)
+    _op.get(op_name).add_argument("data_0", "Tensor", "The input data tensor.")
+    _op.get(op_name).add_argument("data_1", "Tensor", "The input data tensor.")
+    # call default relation functions
+    _op.get(op_name).add_type_rel("Identity")
+    _op.get(op_name).set_support_level(1)
+    _op.register_pattern(op_name, _op.OpPattern.ELEMWISE)
+    _op.register_stateful(op_name, False)
+
+    assert _op.get(op_name).name == op_name
+    assert _op.get(op_name).num_inputs == 2
+    assert _op.get(op_name).get_attr("TOpPattern") == _op.OpPattern.ELEMWISE
+    assert _op.get(op_name).get_attr("TOpIsStateful") == False
 
 
 if __name__ == "__main__":
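The `type_rel_func is None` branch of `ir.OpAddTypeRel` above resolves predefined relations through the global function registry under the `tvm.relay.type_relation.` prefix, so their availability can be checked from Python. A quick sketch (relation names taken from the tests):

```python
import tvm
from tvm import relay  # noqa: F401 -- importing relay registers its type relations

# "Identity" and "Broadcast" are the predefined relations used in the tests below
for rel_name in ["Identity", "Broadcast"]:
    f = tvm.get_global_func("tvm.relay.type_relation." + rel_name, allow_missing=True)
    assert f is not None, "no relation registered for " + rel_name
```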
diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py
index e8179a37756c..a0d37844b837 100644
--- a/tests/python/relay/test_type_infer.py
+++ b/tests/python/relay/test_type_infer.py
@@ -22,6 +22,7 @@
 from tvm import IRModule, te, relay, parser
 from tvm.relay import op, transform, analysis
+from tvm.relay.op import op as _op
 
 
 def infer_mod(mod, annotate_spans=True):
@@ -416,6 +417,134 @@ def test_dynamic_function():
     assert mod["main"].params[0].checked_type == s_tt
 
 
+def test_custom_op_infer():
+    """Tests infer type for custom_op"""
+    op_name = "custom_log"
+    _op.register(op_name, r"code(cal log of a tensor.)code")
+    _op.get(op_name).set_num_inputs(1)
+    _op.get(op_name).add_argument("data_0", "Tensor", "The input data tensor.")
+    # call default relation functions
+    _op.get(op_name).add_type_rel("Identity")
+    _op.get(op_name).set_support_level(1)
+    _op.register_pattern(op_name, _op.OpPattern.ELEMWISE)
+    _op.register_stateful(op_name, False)
+
+    def clog(x):
+        return relay.Call(_op.get(op_name), [x])
+
+    tp = relay.TensorType((10, 10), "float32")
+    x = relay.var("x", tp)
+    sb = relay.ScopeBuilder()
+    t1 = sb.let("t1", clog(x))
+    t2 = sb.let("t2", relay.add(t1, x))
+    sb.ret(t2)
+    f = relay.Function([x], sb.get())
+    fchecked = infer_expr(f)
+    assert fchecked.checked_type == relay.FuncType([tp], tp)
+
+
+def test_custom_add_broadcast_op():
+    """Tests infer type for broadcast custom_op"""
+    op_name = "custom_broadcast_add"
+    _op.register(op_name, r"code(Add two tensor with inner broadcasting.)code")
+    _op.get(op_name).set_num_inputs(2)
+    _op.get(op_name).add_argument("data_0", "Tensor", "The input data tensor.")
+    _op.get(op_name).add_argument("data_1", "Tensor", "The input data tensor.")
+    # call default relation functions
+    _op.get(op_name).add_type_rel("Broadcast")
+    _op.get(op_name).set_support_level(1)
+    _op.register_stateful(op_name, False)
+
+    def broadcast_add(x, y):
+        return relay.Call(_op.get(op_name), [x, y])
+
+    x = relay.var("x", shape=(10, 4))
+    y = relay.var("y", shape=(5, 10, 1))
+    z = broadcast_add(x, y)
+    func = relay.Function([x, y], z)
+    t1 = relay.TensorType((10, 4), "float32")
+    t2 = relay.TensorType((5, 10, 1), "float32")
+    t3 = relay.TensorType((5, 10, 4), "float32")
+    expected_ty = relay.FuncType([t1, t2], t3)
+    assert_has_type(func, expected_ty)
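The expected output type in `test_custom_add_broadcast_op` follows NumPy-style broadcasting: shapes are aligned from the right, and where one side has extent 1 the larger extent wins. A pure-Python sketch of that rule (an illustration only, not TVM's implementation of the Broadcast relation):

```python
from itertools import zip_longest

def broadcast_shape(lhs, rhs):
    """NumPy-style broadcasting of two shapes, aligned from the right."""
    out = []
    for x, y in zip_longest(reversed(lhs), reversed(rhs), fillvalue=1):
        if x != y and 1 not in (x, y):
            raise ValueError("incompatible extents {} and {}".format(x, y))
        out.append(max(x, y))
    return tuple(reversed(out))

# matches the expectation in the test above
assert broadcast_shape((10, 4), (5, 10, 1)) == (5, 10, 4)
```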
+def test_custom_op_rel_infer():
+    """Tests infer type for custom_op"""
+
+    def custom_log1_rel(arg_types, attrs):
+        assert len(arg_types) == 1, "type relation arg number mismatch!"
+        if attrs:
+            assert isinstance(attrs, tvm.ir.DictAttrs)
+        inputa_type = arg_types[0]
+        return relay.TensorType(inputa_type.shape, inputa_type.dtype)
+
+    op_name = "custom_log1"
+    _op.register(op_name, r"code(cal log of a tensor.)code")
+    _op.get(op_name).set_num_inputs(1)
+    _op.get(op_name).add_argument("data_0", "Tensor", "The input data tensor.")
+    _op.get(op_name).set_attrs_type_key("DictAttrs")
+    # call customized relation functions
+    _op.get(op_name).add_type_rel("custom_log1", custom_log1_rel)
+    _op.get(op_name).set_support_level(1)
+    _op.register_pattern(op_name, _op.OpPattern.ELEMWISE)
+    _op.register_stateful(op_name, False)
+
+    def clog(x):
+        return relay.Call(_op.get(op_name), [x])
+
+    tp = relay.TensorType((10, 10), "float32")
+    x = relay.var("x", tp)
+    sb = relay.ScopeBuilder()
+    t1 = sb.let("t1", clog(x))
+    t2 = sb.let("t2", relay.add(t1, x))
+    sb.ret(t2)
+    f = relay.Function([x], sb.get())
+    fchecked = infer_expr(f)
+    assert fchecked.checked_type == relay.FuncType([tp], tp)
+
+
+def test_custom_op_rel_infer_exception():
+    """Tests that type inference for custom_op raises when the relation fails"""
+
+    def custom_log1_rel(arg_types, attrs):
+        assert len(arg_types) == 2, "type relation arg number mismatch!"
+        return None
+
+    op_name = "custom_log2"
+    _op.register(op_name, r"code(cal log of a tensor.)code")
+    _op.get(op_name).set_num_inputs(1)
+    _op.get(op_name).add_argument("data_0", "Tensor", "The input data tensor.")
+    _op.get(op_name).set_attrs_type_key("DictAttrs")
+    # call customized relation functions
+    _op.get(op_name).add_type_rel("custom_log2", custom_log1_rel)
+    _op.get(op_name).set_support_level(1)
+    _op.register_pattern(op_name, _op.OpPattern.ELEMWISE)
+    _op.register_stateful(op_name, False)
+
+    def clog(x):
+        return relay.Call(_op.get(op_name), [x])
+
+    tp = relay.TensorType((10, 10), "float32")
+    x = relay.var("x", tp)
+    sb = relay.ScopeBuilder()
+    t1 = sb.let("t1", clog(x))
+    t2 = sb.let("t2", relay.add(t1, x))
+    sb.ret(t2)
+    f = relay.Function([x], sb.get())
+    with pytest.raises(tvm.error.TVMError) as cm:
+        fchecked = infer_expr(f)
+    assert "type relation arg number mismatch" in str(cm.value)
+
+
+def test_repeat_register():
+    op_name = "custom_log3"
+    _op.register(op_name, r"code(cal log of a tensor.)code")
+    with pytest.raises(tvm.error.TVMError) as cm:
+        _op.register(op_name)
+    assert "Operator custom_log3 is registered before" in str(cm.value)
+
+
 if __name__ == "__main__":
     import sys