[OpenCL] Register a lowering rule for tir.erf

Adds an "opencl.FLowerIntrinsic" entry for tir.erf that dispatches through
DispatchPureExtern<Direct>, the same rule the neighbouring tir.exp entry uses.

Adds test_opencl_erf, which builds a one-element erf kernel for float32 and
float64 and asserts that the generated OpenCL source contains "erf" exactly
once and never "erff".

diff --git a/src/target/source/intrin_rule_opencl.cc b/src/target/source/intrin_rule_opencl.cc
index 288bb2cfc069..64a50c3c84b1 100644
--- a/src/target/source/intrin_rule_opencl.cc
+++ b/src/target/source/intrin_rule_opencl.cc
@@ -49,6 +49,9 @@ TVM_REGISTER_OP("tir.round")
 TVM_REGISTER_OP("tir.exp").set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic",
                                                      DispatchPureExtern<Direct>);
 
+TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic",
+                                                     DispatchPureExtern<Direct>);
+
 TVM_REGISTER_OP("tir.exp2")
     .set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchPureExtern<Direct>);
 
diff --git a/tests/python/unittest/test_target_codegen_opencl.py b/tests/python/unittest/test_target_codegen_opencl.py
index 98340f0e6ac5..56392ec8cccc 100644
--- a/tests/python/unittest/test_target_codegen_opencl.py
+++ b/tests/python/unittest/test_target_codegen_opencl.py
@@ -17,6 +17,7 @@
 import tvm
 from tvm import te
 import tvm.testing
+import re
 
 target = "opencl"
 
@@ -120,6 +121,25 @@ def check_max(dev, n, dtype):
     check_max(dev, 1, "float64")
 
 
+def test_opencl_erf():
+    def check_erf(dev, n, dtype):
+        A = te.placeholder((n,), name="A", dtype=dtype)
+        C = te.compute(A.shape, lambda *i: te.erf(A(*i)), name="C")
+        s = te.create_schedule(C.op)
+        s[C].bind(s[C].op.axis[0], te.thread_axis("threadIdx.x"))
+        fun = tvm.build(s, [A, C], target)
+        source_str = fun.imported_modules[0].get_source()
+        matches = re.findall("erf", source_str)
+        error_matches = re.findall("erff", source_str)
+        assert len(matches) == 1 and len(error_matches) == 0
+
+    dev = tvm.device(target, 0)
+
+    check_erf(dev, 1, "float32")
+    check_erf(dev, 1, "float64")
+
+
 if __name__ == "__main__":
     test_opencl_ternary_expression()
     test_opencl_inf_nan()
+    test_opencl_erf()