Skip to content

Commit

Permalink
Attach hook using tvm.register_func directly
Browse files Browse the repository at this point in the history
This is both cleaner and allows using the `override` function parameter to allow overwriting.
  • Loading branch information
Mousius committed Jul 22, 2021
1 parent 011ebc9 commit b5b98eb
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions tests/python/relay/test_additional_target_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
)


def translate_relay_add_to_tir_subtract():
def translate_relay_add_to_tir_subtract(relay_func, target):
"""A transform to test Relay -> TIR with"""
ib = tvm.tir.ir_builder.create()
A = tvm.tir.decl_buffer(
Expand Down Expand Up @@ -67,9 +67,7 @@ def translate_relay_add_to_tir_subtract():
"check_result", [check_vm_result, check_graph_executor_result, check_aot_executor_result]
)
def test_tir_external_generation(check_result):
@tvm.register_func("target.test.tir_lowering")
def relay_to_tir(expr, target):
return translate_relay_add_to_tir_subtract(expr)
tvm.register_func("target.test.tir_lowering", translate_relay_add_to_tir_subtract, True)

shape = (8, 8)
x_data = np.random.randint(255, size=shape).astype("float32")
Expand Down

0 comments on commit b5b98eb

Please sign in to comment.