Skip to content

Commit

Permalink
Complete register op from python (apache#8079)
Browse files Browse the repository at this point in the history
* Complete register op from python

* fix lint

* fix lint

* fix lint

* fix comments

* fix

* fix

* fix comments

* fix lint

* fix lint

* add comments

* fix build

* fix

* add exception case

* fix

* fix comments

* fix

* fix

* fix

* fix

* fix

* fix

* fix

Co-authored-by: xiaoqiang.dan <xiaoqiang.dan@streamcoputing.com>
  • Loading branch information
2 people authored and trevor-m committed Jun 17, 2021
1 parent e57780c commit 7ba0941
Show file tree
Hide file tree
Showing 7 changed files with 318 additions and 17 deletions.
14 changes: 13 additions & 1 deletion include/tvm/ir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,12 +244,18 @@ class OpRegEntry {
runtime::TypedPackedFunc<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>
type_rel_func);
/*!
 * \brief Set the attrs type key and index to be AttrsType.
 * \tparam AttrsType the attribute type to be set.
* \return reference to self.
*/
template <typename AttrsType>
inline OpRegEntry& set_attrs_type();
/*!
* \brief Set the attrs type key and index to be AttrsType.
* \param key The attribute type key to be set.
* \return reference to self.
*/
inline OpRegEntry& set_attrs_type_key(const String& key);
/*!
* \brief Set the num_inputs
* \param n The number of inputs to be set.
Expand Down Expand Up @@ -454,6 +460,12 @@ inline OpRegEntry& OpRegEntry::set_attrs_type() { // NOLINT(*)
return *this;
}

// Set the attrs type of this op by its string type key; also stores the
// resolved runtime type index (via Object::TypeKey2Index) on the op node.
// NOTE(review): assumes `key` names a type already registered with the
// object system — TypeKey2Index would fail otherwise; confirm with callers.
inline OpRegEntry& OpRegEntry::set_attrs_type_key(const String& key) {  // NOLINT(*)
  get()->attrs_type_key = key;
  get()->attrs_type_index = Object::TypeKey2Index(key);
  return *this;
}

inline OpRegEntry& OpRegEntry::set_support_level(int32_t n) { // NOLINT(*)
get()->support_level = n;
return *this;
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from .tensor_type import TensorType
from .type_relation import TypeCall, TypeRelation
from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, Range
from .op import Op, register_op, register_op_attr, register_intrin_lowering
from .op import Op, register_op_attr, register_intrin_lowering
from .function import CallingConv, BaseFunc
from .adt import Constructor, TypeData
from .module import IRModule
Expand Down
70 changes: 62 additions & 8 deletions python/tvm/ir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,17 +85,71 @@ def reset_attr(self, attr_name):
"""
_ffi_api.OpResetAttr(self, attr_name)

def add_type_rel(self, rel_name, type_rel_func=None):
    """Attach the type function corresponding to the return type.

    Parameters
    ----------
    rel_name : str
        The type relation name to register.

    type_rel_func : Optional[function (args: List[Type], attrs: Attrs) -> Type]
        The backing relation function which can solve an arbitrary relation on variables.
        Differences with type_rel_func in C++:

        1) When type_rel_func is not None:

           a) OpAddTypeRel on the C++ side will adjust type_rel_func with TypeReporter
              to the calling convention of the relay type system.

           b) type_rel_func returns the output argument's type; returning None means
              the output's type cannot be inferred.

           c) Only single-output operators are supported for now; the last argument
              is the output tensor.

        2) When type_rel_func is None, a predefined type relation registered in relay
           under ``tvm.relay.type_relation.`` + rel_name is used instead.
    """
    _ffi_api.OpAddTypeRel(self, rel_name, type_rel_func)
def add_argument(self, name, type, description):  # pylint: disable=redefined-builtin
    """Add argument information to the function.

    Parameters
    ----------
    name : str
        The argument name.

    type : str
        The argument type.

    description : str
        The argument description.
    """
    _ffi_api.OpAddArgument(self, name, type, description)

def set_support_level(self, level):
    """Assign the support level of this operator.

    Parameters
    ----------
    level : int
        The support level to record for the op.
    """
    _ffi_api.OpSetSupportLevel(self, level)

def set_num_inputs(self, n):
    """Set the number of inputs of the op.

    Parameters
    ----------
    n : int
        The input number.
    """
    _ffi_api.OpSetNumInputs(self, n)

def set_attrs_type_key(self, key):
    """Record the attribute type key for this operator.

    Parameters
    ----------
    key : str
        The attribute type key to set.
    """
    _ffi_api.OpSetAttrsTypeKey(self, key)


def register_op_attr(op_name, attr_key, value=None, level=10):
Expand Down
35 changes: 35 additions & 0 deletions python/tvm/relay/op/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from tvm.driver import lower, build
from tvm.target import get_native_generic_func, GenericFunc
from tvm.runtime import Object
import tvm.ir._ffi_api
from . import _make


Expand All @@ -40,6 +41,40 @@ def get(op_name):
return tvm.ir.Op.get(op_name)


def register(op_name, describe=""):
    """Register a new empty operator under the given name.

    When op_name is not registered, a new empty op with the given name is
    created; when op_name has already been registered, the C++ side aborts
    with an error message (an op may only be registered once).

    Parameters
    ----------
    op_name : str
        The operator name

    describe : Optional[str]
        The operator description
    """

    tvm.ir._ffi_api.RegisterOp(op_name, describe)


def register_stateful(op_name, stateful, level=10):
    """Register the stateful flag ("TOpIsStateful" attribute) for an op.

    Parameters
    ----------
    op_name : str
        The name of the op.

    stateful : bool
        The stateful flag.

    level : int
        The priority level
    """
    tvm.ir.register_op_attr(op_name, "TOpIsStateful", stateful, level)


class OpPattern(object):
"""Operator generic patterns
Expand Down
65 changes: 63 additions & 2 deletions src/ir/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,71 @@ TVM_REGISTER_GLOBAL("ir.OpResetAttr").set_body_typed([](Op op, String attr_name)
reg.reset_attr(attr_name);
});

TVM_REGISTER_GLOBAL("ir.RegisterOp").set_body_typed([](String op_name) {
TVM_REGISTER_GLOBAL("ir.RegisterOp").set_body_typed([](String op_name, String descr) {
const OpRegEntry* reg = OpRegistry::Global()->Get(op_name);
ICHECK(reg == nullptr) << "AttributeError: Operator " << op_name << " is registered before";
OpRegistry::Global()->RegisterOrGet(op_name).set_name();
auto& op = OpRegistry::Global()->RegisterOrGet(op_name).set_name();
op.describe(descr);
});

// This is exposed FFI api for prototyping using in python.
// Note: it is not full of the C++ type relation,
// since in python side we don't have access to the type reporter,
// and cannot propagate constraints to the inputs, only to the output.
TVM_REGISTER_GLOBAL("ir.OpAddTypeRel")
    .set_body_typed([](Op op, String rel_name, runtime::TVMArgValue value) {
      auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name();
      if (value.type_code() == kTVMPackedFuncHandle) {
        // Eagerly copy the PackedFunc into the closure *by value* so the
        // callback stays alive after the frontend drops its reference.
        // (Replaces the original leaked `new PackedFunc` — the closure now
        // owns the copy with ordinary RAII lifetime.)
        PackedFunc fcopy = value.operator tvm::runtime::PackedFunc();
        auto f = [fcopy](const Array<Type>& args, int num_inputs, const Attrs& attrs,
                         const TypeReporter& reporter) -> bool {
          // All but the last argument are inputs; the last is the output.
          Array<Type> input_types(args.begin(), args.end() - 1);
          // Call the customized relation function.
          // fcopy's signature: function (args: List[Type], attrs: Attrs) -> Type
          Type ret_type = fcopy(input_types, attrs);
          // A defined ret_type means inference of the output type succeeded:
          // assign it to the output slot. Undefined means inference failed.
          if (ret_type.defined()) {
            // TODO(xqdan): support multiple outputs
            reporter->Assign(args.back(), ret_type);
            return true;
          }
          return false;
        };
        // Adjust the function call to the calling convention of the relay
        // type system (which threads a TypeReporter through).
        auto type_rel = runtime::TypedPackedFunc<bool(const Array<Type>&, int, const Attrs&,
                                                      const TypeReporter&)>(f);
        reg.add_type_rel(rel_name, type_rel);
      } else if (value.type_code() == kTVMNullptr) {
        // No python function given: fall back to a predefined relay relation.
        auto func_name = std::string("tvm.relay.type_relation.") + rel_name;
        auto* f = runtime::Registry::Get(func_name);
        ICHECK(f != nullptr) << "AddTypeRel error: no type_relation registered.";
        reg.add_type_rel(rel_name, *f);
      }
    });

// FFI hook: attach argument metadata (name/type/description) to an op.
TVM_REGISTER_GLOBAL("ir.OpAddArgument")
    .set_body_typed([](Op op, String name, String type, String description) {
      auto& entry = OpRegistry::Global()->RegisterOrGet(op->name).set_name();
      entry.add_argument(name, type, description);
    });

// FFI hook: set the support level of an op from the frontend.
TVM_REGISTER_GLOBAL("ir.OpSetSupportLevel").set_body_typed([](Op op, int level) {
  OpRegistry::Global()->RegisterOrGet(op->name).set_name().set_support_level(level);
});

// FFI hook: set the number of inputs of an op from the frontend.
TVM_REGISTER_GLOBAL("ir.OpSetNumInputs").set_body_typed([](Op op, int n) {
  OpRegistry::Global()->RegisterOrGet(op->name).set_name().set_num_inputs(n);
});

// FFI hook: set the attrs type key (and its cached index) of an op.
TVM_REGISTER_GLOBAL("ir.OpSetAttrsTypeKey").set_body_typed([](Op op, String key) {
  OpRegistry::Global()->RegisterOrGet(op->name).set_name().set_attrs_type_key(key);
});

TVM_REGISTER_GLOBAL("ir.RegisterOpAttr")
Expand Down
20 changes: 15 additions & 5 deletions tests/python/relay/test_ir_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import tvm
from tvm import relay
from tvm.relay.testing.temp_op_attr import TempOpAttr
from tvm.relay.op import op as _op


def test_op_attr():
Expand Down Expand Up @@ -103,11 +104,20 @@ def test_op_register():
"""Tests register_op functionality."""
op_name = "custom_op"

tvm.ir.register_op(op_name)
tvm.ir.register_op_attr(op_name, "num_inputs", 2, 256)

assert tvm.ir.Op.get(op_name).name == op_name
assert tvm.ir.Op.get(op_name).num_inputs == 2
_op.register(op_name, r"code(Add two tensor with inner broadcasting.)code")
_op.get(op_name).set_num_inputs(2)
_op.get(op_name).add_argument("data_0", "Tensor", "The input data tensor.")
_op.get(op_name).add_argument("data_1", "Tensor", "The input data tensor.")
# call default relation functions
_op.get(op_name).add_type_rel("Identity")
_op.get(op_name).set_support_level(1)
_op.register_pattern(op_name, _op.OpPattern.ELEMWISE)
_op.register_stateful(op_name, False)

assert _op.get(op_name).name == op_name
assert _op.get(op_name).num_inputs == 2
assert _op.get(op_name).get_attr("TOpPattern") == _op.OpPattern.ELEMWISE
assert _op.get(op_name).get_attr("TOpIsStateful") == False


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 7ba0941

Please sign in to comment.