diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index 186071f22764..87eea97a8b04 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -350,7 +350,8 @@ def te_gelu_tanh(x: te.Tensor):
                 tir.const(1.0, dtype)
                 + topi.tanh(
                     tir.const(math.sqrt(2.0 / math.pi), dtype)
-                    * (x + tir.const(0.044715, dtype) * topi.power(x, 3))
+                    * x
+                    * (1 + tir.const(0.044715, dtype) * x * x)
                 )
             )
         )
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index fd14f4892154..9f35f73a62cf 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -688,6 +688,32 @@ TVM_REGISTER_GLOBAL("tir.bitwise_not").set_body_typed([](PrimExpr a, Span span)
 PrimExpr pow(PrimExpr x, PrimExpr y, Span span) {
   BinaryOpMatchTypes(x, y, span);
   ICHECK(x.dtype().is_float()) << "power only applies to float";
+
+  // If we detect pow(x, y) with constant y >= 3, suggest using x * x * x instead
+  if (y.dtype().is_int()) {
+    using tir::IntImmNode;
+    const IntImmNode* px = y.as<IntImmNode>();
+    if (px) {
+      if (px->value >= 3) {
+        LOG(WARNING)
+            << "Detected pow(x, y) where y >= 3; it is recommended to avoid this, as it may "
+               "lead to unintended behavior when x < 0. Perhaps rewrite it as `x * x * x ...` "
+               "or `pow(x, 2) * pow(x, 2) ...`.";
+      }
+    }
+  } else if (y.dtype().is_float()) {
+    using tir::FloatImmNode;
+    const FloatImmNode* fx = y.as<FloatImmNode>();
+    if (fx) {
+      if (fx->value >= 3.0) {
+        LOG(WARNING)
+            << "Detected pow(x, y) where y >= 3; it is recommended to avoid this, as it may "
+               "lead to unintended behavior when x < 0. Perhaps rewrite it as `x * x * x ...` "
+               "or `pow(x, 2) * pow(x, 2) ...`.";
+      }
+    }
+  }
+
   static auto op = Op::Get("tir.pow");
   return tir::Call(x.dtype(), op, {x, y}, span);
 }
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index 07fbc3419b98..45e6bd878a95 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -1259,10 +1259,11 @@ def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float3
     def gelu_tanh(A: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_multiply: T.Buffer((T.int64(2), T.int64(3)), "float32")):
         T.func_attr({"tir.noalias": T.bool(True)})
         T_multiply_1 = T.alloc_buffer((T.int64(2), T.int64(3)))
-        T_power = T.alloc_buffer((T.int64(2), T.int64(3)))
         T_multiply_2 = T.alloc_buffer((T.int64(2), T.int64(3)))
-        T_add = T.alloc_buffer((T.int64(2), T.int64(3)))
         T_multiply_3 = T.alloc_buffer((T.int64(2), T.int64(3)))
+        T_multiply_4 = T.alloc_buffer((T.int64(2), T.int64(3)))
+        T_add = T.alloc_buffer((T.int64(2), T.int64(3)))
+        T_multiply_5 = T.alloc_buffer((T.int64(2), T.int64(3)))
         compute = T.alloc_buffer((T.int64(2), T.int64(3)))
         T_add_1 = T.alloc_buffer((T.int64(2), T.int64(3)))
         for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
@@ -1272,35 +1273,41 @@ def gelu_tanh(A: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_multiply: T.Bu
                 T.writes(T_multiply_1[v_ax0, v_ax1])
                 T_multiply_1[v_ax0, v_ax1] = T.float32(0.5) * A[v_ax0, v_ax1]
         for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
-            with T.block("T_power"):
+            with T.block("T_multiply_1"):
                 v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                 T.reads(A[v_ax0, v_ax1])
-                T.writes(T_power[v_ax0, v_ax1])
-                T_power[v_ax0, v_ax1] = T.pow(A[v_ax0, v_ax1], T.float32(3))
+                T.writes(T_multiply_2[v_ax0, v_ax1])
+                T_multiply_2[v_ax0, v_ax1] = T.float32(0.79788456080286541) * A[v_ax0, v_ax1]
         for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
-            with T.block("T_multiply_1"):
T.block("T_multiply_2"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(T_power[v_ax0, v_ax1]) - T.writes(T_multiply_2[v_ax0, v_ax1]) - T_multiply_2[v_ax0, v_ax1] = T.float32(0.044714999999999998) * T_power[v_ax0, v_ax1] + T.reads(A[v_ax0, v_ax1]) + T.writes(T_multiply_3[v_ax0, v_ax1]) + T_multiply_3[v_ax0, v_ax1] = T.float32(0.044714999999999998) * A[v_ax0, v_ax1] + for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_multiply_3"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(T_multiply_3[v_ax0, v_ax1], A[v_ax0, v_ax1]) + T.writes(T_multiply_4[v_ax0, v_ax1]) + T_multiply_4[v_ax0, v_ax1] = T_multiply_3[v_ax0, v_ax1] * A[v_ax0, v_ax1] for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): with T.block("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(A[v_ax0, v_ax1], T_multiply_2[v_ax0, v_ax1]) + T.reads(T_multiply_4[v_ax0, v_ax1]) T.writes(T_add[v_ax0, v_ax1]) - T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + T_multiply_2[v_ax0, v_ax1] + T_add[v_ax0, v_ax1] = T.float32(1) + T_multiply_4[v_ax0, v_ax1] for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_multiply_2"): + with T.block("T_multiply_4"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(T_add[v_ax0, v_ax1]) - T.writes(T_multiply_3[v_ax0, v_ax1]) - T_multiply_3[v_ax0, v_ax1] = T.float32(0.79788456080286541) * T_add[v_ax0, v_ax1] + T.reads(T_multiply_2[v_ax0, v_ax1], T_add[v_ax0, v_ax1]) + T.writes(T_multiply_5[v_ax0, v_ax1]) + T_multiply_5[v_ax0, v_ax1] = T_multiply_2[v_ax0, v_ax1] * T_add[v_ax0, v_ax1] for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.block("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) - T.reads(T_multiply_3[v_i0, v_i1]) + T.reads(T_multiply_5[v_i0, v_i1]) T.writes(compute[v_i0, v_i1]) - compute[v_i0, v_i1] = T.tanh(T_multiply_3[v_i0, v_i1]) + compute[v_i0, v_i1] = T.tanh(T_multiply_5[v_i0, v_i1]) for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): with T.block("T_add_1"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) @@ -1308,7 +1315,7 @@ def gelu_tanh(A: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_multiply: T.Bu T.writes(T_add_1[v_ax0, v_ax1]) T_add_1[v_ax0, v_ax1] = T.float32(1) + compute[v_ax0, v_ax1] for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_multiply_3"): + with T.block("T_multiply_5"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(T_multiply_1[v_ax0, v_ax1], T_add_1[v_ax0, v_ax1]) T.writes(T_multiply[v_ax0, v_ax1]) @@ -1344,11 +1351,13 @@ def gelu_tanh(var_A: T.handle, var_T_multiply: T.handle): m, n = T.int64(), T.int64() A = T.match_buffer(var_A, (m, n)) T_multiply = T.match_buffer(var_T_multiply, (m, n)) + # with T.block("root"): T_multiply_1 = T.alloc_buffer((m, n)) - T_power = T.alloc_buffer((m, n)) T_multiply_2 = T.alloc_buffer((m, n)) - T_add = T.alloc_buffer((m, n)) T_multiply_3 = T.alloc_buffer((m, n)) + T_multiply_4 = T.alloc_buffer((m, n)) + T_add = T.alloc_buffer((m, n)) + T_multiply_5 = T.alloc_buffer((m, n)) compute = T.alloc_buffer((m, n)) T_add_1 = T.alloc_buffer((m, n)) for ax0, ax1 in T.grid(m, n): @@ -1358,35 +1367,41 @@ def gelu_tanh(var_A: T.handle, var_T_multiply: T.handle): T.writes(T_multiply_1[v_ax0, v_ax1]) T_multiply_1[v_ax0, v_ax1] = T.float32(0.5) * A[v_ax0, v_ax1] for ax0, ax1 in T.grid(m, n): - with T.block("T_power"): + with T.block("T_multiply_1"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax0, v_ax1]) - T.writes(T_power[v_ax0, v_ax1]) - T_power[v_ax0, v_ax1] = T.pow(A[v_ax0, v_ax1], T.float32(3)) + T.writes(T_multiply_2[v_ax0, v_ax1]) + T_multiply_2[v_ax0, 
+                T_multiply_2[v_ax0, v_ax1] = T.float32(0.79788456080286541) * A[v_ax0, v_ax1]
         for ax0, ax1 in T.grid(m, n):
-            with T.block("T_multiply_1"):
+            with T.block("T_multiply_2"):
                 v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
-                T.reads(T_power[v_ax0, v_ax1])
-                T.writes(T_multiply_2[v_ax0, v_ax1])
-                T_multiply_2[v_ax0, v_ax1] = T.float32(0.044714999999999998) * T_power[v_ax0, v_ax1]
+                T.reads(A[v_ax0, v_ax1])
+                T.writes(T_multiply_3[v_ax0, v_ax1])
+                T_multiply_3[v_ax0, v_ax1] = T.float32(0.044714999999999998) * A[v_ax0, v_ax1]
+        for ax0, ax1 in T.grid(m, n):
+            with T.block("T_multiply_3"):
+                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(T_multiply_3[v_ax0, v_ax1], A[v_ax0, v_ax1])
+                T.writes(T_multiply_4[v_ax0, v_ax1])
+                T_multiply_4[v_ax0, v_ax1] = T_multiply_3[v_ax0, v_ax1] * A[v_ax0, v_ax1]
         for ax0, ax1 in T.grid(m, n):
             with T.block("T_add"):
                 v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
-                T.reads(A[v_ax0, v_ax1], T_multiply_2[v_ax0, v_ax1])
+                T.reads(T_multiply_4[v_ax0, v_ax1])
                 T.writes(T_add[v_ax0, v_ax1])
-                T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + T_multiply_2[v_ax0, v_ax1]
+                T_add[v_ax0, v_ax1] = T.float32(1) + T_multiply_4[v_ax0, v_ax1]
         for ax0, ax1 in T.grid(m, n):
-            with T.block("T_multiply_2"):
+            with T.block("T_multiply_4"):
                 v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
-                T.reads(T_add[v_ax0, v_ax1])
-                T.writes(T_multiply_3[v_ax0, v_ax1])
-                T_multiply_3[v_ax0, v_ax1] = T.float32(0.79788456080286541) * T_add[v_ax0, v_ax1]
+                T.reads(T_multiply_2[v_ax0, v_ax1], T_add[v_ax0, v_ax1])
+                T.writes(T_multiply_5[v_ax0, v_ax1])
+                T_multiply_5[v_ax0, v_ax1] = T_multiply_2[v_ax0, v_ax1] * T_add[v_ax0, v_ax1]
         for i0, i1 in T.grid(m, n):
             with T.block("compute"):
                 v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
-                T.reads(T_multiply_3[v_i0, v_i1])
+                T.reads(T_multiply_5[v_i0, v_i1])
                 T.writes(compute[v_i0, v_i1])
-                compute[v_i0, v_i1] = T.tanh(T_multiply_3[v_i0, v_i1])
+                compute[v_i0, v_i1] = T.tanh(T_multiply_5[v_i0, v_i1])
         for ax0, ax1 in T.grid(m, n):
             with T.block("T_add_1"):
                 v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
@@ -1394,7 +1409,7 @@ def gelu_tanh(var_A: T.handle, var_T_multiply: T.handle):
                 T.writes(T_add_1[v_ax0, v_ax1])
                 T_add_1[v_ax0, v_ax1] = T.float32(1) + compute[v_ax0, v_ax1]
         for ax0, ax1 in T.grid(m, n):
-            with T.block("T_multiply_3"):
+            with T.block("T_multiply_5"):
                 v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                 T.reads(T_multiply_1[v_ax0, v_ax1], T_add_1[v_ax0, v_ax1])
                 T.writes(T_multiply[v_ax0, v_ax1])
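For reference (not part of the patch): the nn.py change is exact algebra,
sqrt(2/pi) * (x + 0.044715 * x**3) == sqrt(2/pi) * x * (1 + 0.044715 * x * x),
and the new `pow` warning targets lowerings that implement pow(x, y) as
exp(y * log(x)), which yields NaN for x < 0. A minimal NumPy sketch of both
points, with arbitrary sample values:

    import math
    import numpy as np

    x = np.array([-1.5, -0.3, 0.7, 2.0], dtype="float32")

    # The rewritten inner GELU-tanh term matches the original exactly
    # (up to float rounding).
    old = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)
    new = math.sqrt(2.0 / math.pi) * x * (1 + 0.044715 * x * x)
    np.testing.assert_allclose(old, new, rtol=1e-6)

    # The hazard behind the warning: pow lowered as exp(y * log(x))
    # is NaN for negative bases, while x * x * x stays well defined.
    with np.errstate(invalid="ignore"):
        lowered_pow = np.exp(3.0 * np.log(x))
    print(lowered_pow)  # [nan nan ~0.343 ~8.0]
    print(x * x * x)    # [-3.375 -0.027 0.343 8.0]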