From 254563a3140cf63fe77a46058688209de3aa213c Mon Sep 17 00:00:00 2001 From: zackcquic <82788137+zackcquic@users.noreply.github.com> Date: Fri, 7 May 2021 23:29:02 +0800 Subject: [PATCH] [RELAY] Enable registering op with python (#8002) Add a new API register_op Note: Implementing a op by pure python is still limited: 1. Custom type relation (add_type_rel()) is still not available in python. 2. Setting number inputs (set_num_inputs()) needs plevel > 128 in python. (see tests/python/relay/test_ir_op.py) --- python/tvm/ir/__init__.py | 2 +- python/tvm/ir/op.py | 12 ++++++++++++ src/ir/op.cc | 6 ++++++ tests/python/relay/test_ir_op.py | 16 ++++++++++++++-- 4 files changed, 33 insertions(+), 3 deletions(-) diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py index a12d3e9855f0..4bc7f1ae4468 100644 --- a/python/tvm/ir/__init__.py +++ b/python/tvm/ir/__init__.py @@ -23,7 +23,7 @@ from .tensor_type import TensorType from .type_relation import TypeCall, TypeRelation from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, Range -from .op import Op, register_op_attr, register_intrin_lowering +from .op import Op, register_op, register_op_attr, register_intrin_lowering from .function import CallingConv, BaseFunc from .adt import Constructor, TypeData from .module import IRModule diff --git a/python/tvm/ir/op.py b/python/tvm/ir/op.py index 88e760ef91a1..b4cbd5563cda 100644 --- a/python/tvm/ir/op.py +++ b/python/tvm/ir/op.py @@ -86,6 +86,18 @@ def reset_attr(self, attr_name): _ffi_api.OpResetAttr(self, attr_name) +def register_op(op_name): + """Register an operator by name + + Parameters + ---------- + op_name : str + The name of new operator + """ + + _ffi_api.RegisterOp(op_name) + + def register_op_attr(op_name, attr_key, value=None, level=10): """Register an operator property of an operator by name. diff --git a/src/ir/op.cc b/src/ir/op.cc index 5b258ed2f2f0..8fd34d30ffa7 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -102,6 +102,12 @@ TVM_REGISTER_GLOBAL("ir.OpResetAttr").set_body_typed([](Op op, String attr_name) reg.reset_attr(attr_name); }); +TVM_REGISTER_GLOBAL("ir.RegisterOp").set_body_typed([](String op_name) { + const OpRegEntry* reg = OpRegistry::Global()->Get(op_name); + ICHECK(reg == nullptr) << "AttributeError: Operator " << op_name << " is registered before"; + OpRegistry::Global()->RegisterOrGet(op_name).set_name(); +}); + TVM_REGISTER_GLOBAL("ir.RegisterOpAttr") .set_body_typed([](String op_name, String attr_key, runtime::TVMArgValue value, int plevel) { auto& reg = OpRegistry::Global()->RegisterOrGet(op_name).set_name(); diff --git a/tests/python/relay/test_ir_op.py b/tests/python/relay/test_ir_op.py index 34c000017d12..fe559697348b 100644 --- a/tests/python/relay/test_ir_op.py +++ b/tests/python/relay/test_ir_op.py @@ -32,7 +32,7 @@ def test(x): def test_op_reset_attr(): - """ Tests reset_attr functionality. """ + """Tests reset_attr functionality.""" def add1(x): return x + 1 @@ -60,7 +60,7 @@ def add2(x): def test_op_temp_attr(): - """ Tests reset_attr functionality. """ + """Tests reset_attr functionality.""" def add1(x): return x + 1 @@ -99,9 +99,21 @@ def test_op_level3(): assert y.args[0] == x +def test_op_register(): + """Tests register_op functionality.""" + op_name = "custom_op" + + tvm.ir.register_op(op_name) + tvm.ir.register_op_attr(op_name, "num_inputs", 2, 256) + + assert tvm.ir.Op.get(op_name).name == op_name + assert tvm.ir.Op.get(op_name).num_inputs == 2 + + if __name__ == "__main__": test_op_attr() test_op_reset_attr() test_op_temp_attr() test_op_level1() test_op_level3() + test_op_register()