Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Use x*x*x instead of pow(x,3) #16518

Merged
merged 2 commits into from
Feb 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion python/tvm/relax/transform/legalize_ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,8 @@ def te_gelu_tanh(x: te.Tensor):
tir.const(1.0, dtype)
+ topi.tanh(
tir.const(math.sqrt(2.0 / math.pi), dtype)
* (x + tir.const(0.044715, dtype) * topi.power(x, 3))
* x
* (1 + tir.const(0.044715, dtype) * x * x)
)
)
)
Expand Down
26 changes: 26 additions & 0 deletions src/tir/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,32 @@ TVM_REGISTER_GLOBAL("tir.bitwise_not").set_body_typed([](PrimExpr a, Span span)
PrimExpr pow(PrimExpr x, PrimExpr y, Span span) {
BinaryOpMatchTypes(x, y, span);
ICHECK(x.dtype().is_float()) << "power only applies to float";

// If we detect pow(x, 3), suggest using x * x * x
if (y.dtype().is_int()) {
using tir::IntImmNode;
const IntImmNode* px = y.as<IntImmNode>();
if (px) {
if (px->value >= 3) {
LOG(WARNING)
<< "Detected pow(x, y) where y >= 3, it is recommended to avoid this as it may lead to "
"uninteded behaviors when x < 0. Perhaps with `x * x * x ...` or "
"`pow(x, 2) * pow(x, 2) ...`.";
}
}
} else if (y.dtype().is_float()) {
using tir::FloatImmNode;
const FloatImmNode* fx = y.as<FloatImmNode>();
if (fx) {
if (fx->value >= 3.0) {
LOG(WARNING)
<< "Detected pow(x, y) where y >= 3, it is recommended to avoid this as it may lead to "
"uninteded behaviors when x < 0. Perhaps with `x * x * x ...` or "
"`pow(x, 2) * pow(x, 2) ...`.";
}
}
}

static auto op = Op::Get("tir.pow");
return tir::Call(x.dtype(), op, {x, y}, span);
}
Expand Down
87 changes: 51 additions & 36 deletions tests/python/relax/test_transform_legalize_ops_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1259,10 +1259,11 @@ def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float3
def gelu_tanh(A: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_multiply: T.Buffer((T.int64(2), T.int64(3)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
T_multiply_1 = T.alloc_buffer((T.int64(2), T.int64(3)))
T_power = T.alloc_buffer((T.int64(2), T.int64(3)))
T_multiply_2 = T.alloc_buffer((T.int64(2), T.int64(3)))
T_add = T.alloc_buffer((T.int64(2), T.int64(3)))
T_multiply_3 = T.alloc_buffer((T.int64(2), T.int64(3)))
T_multiply_4 = T.alloc_buffer((T.int64(2), T.int64(3)))
T_add = T.alloc_buffer((T.int64(2), T.int64(3)))
T_multiply_5 = T.alloc_buffer((T.int64(2), T.int64(3)))
compute = T.alloc_buffer((T.int64(2), T.int64(3)))
T_add_1 = T.alloc_buffer((T.int64(2), T.int64(3)))
for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
Expand All @@ -1272,43 +1273,49 @@ def gelu_tanh(A: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_multiply: T.Bu
T.writes(T_multiply_1[v_ax0, v_ax1])
T_multiply_1[v_ax0, v_ax1] = T.float32(0.5) * A[v_ax0, v_ax1]
for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
with T.block("T_power"):
with T.block("T_multiply_1"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(A[v_ax0, v_ax1])
T.writes(T_power[v_ax0, v_ax1])
T_power[v_ax0, v_ax1] = T.pow(A[v_ax0, v_ax1], T.float32(3))
T.writes(T_multiply_2[v_ax0, v_ax1])
T_multiply_2[v_ax0, v_ax1] = T.float32(0.79788456080286541) * A[v_ax0, v_ax1]
for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
with T.block("T_multiply_1"):
with T.block("T_multiply_2"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(T_power[v_ax0, v_ax1])
T.writes(T_multiply_2[v_ax0, v_ax1])
T_multiply_2[v_ax0, v_ax1] = T.float32(0.044714999999999998) * T_power[v_ax0, v_ax1]
T.reads(A[v_ax0, v_ax1])
T.writes(T_multiply_3[v_ax0, v_ax1])
T_multiply_3[v_ax0, v_ax1] = T.float32(0.044714999999999998) * A[v_ax0, v_ax1]
for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
with T.block("T_multiply_3"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(T_multiply_3[v_ax0, v_ax1], A[v_ax0, v_ax1])
T.writes(T_multiply_4[v_ax0, v_ax1])
T_multiply_4[v_ax0, v_ax1] = T_multiply_3[v_ax0, v_ax1] * A[v_ax0, v_ax1]
for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(A[v_ax0, v_ax1], T_multiply_2[v_ax0, v_ax1])
T.reads(T_multiply_4[v_ax0, v_ax1])
T.writes(T_add[v_ax0, v_ax1])
T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + T_multiply_2[v_ax0, v_ax1]
T_add[v_ax0, v_ax1] = T.float32(1) + T_multiply_4[v_ax0, v_ax1]
for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
with T.block("T_multiply_2"):
with T.block("T_multiply_4"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(T_add[v_ax0, v_ax1])
T.writes(T_multiply_3[v_ax0, v_ax1])
T_multiply_3[v_ax0, v_ax1] = T.float32(0.79788456080286541) * T_add[v_ax0, v_ax1]
T.reads(T_multiply_2[v_ax0, v_ax1], T_add[v_ax0, v_ax1])
T.writes(T_multiply_5[v_ax0, v_ax1])
T_multiply_5[v_ax0, v_ax1] = T_multiply_2[v_ax0, v_ax1] * T_add[v_ax0, v_ax1]
for i0, i1 in T.grid(T.int64(2), T.int64(3)):
with T.block("compute"):
v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
T.reads(T_multiply_3[v_i0, v_i1])
T.reads(T_multiply_5[v_i0, v_i1])
T.writes(compute[v_i0, v_i1])
compute[v_i0, v_i1] = T.tanh(T_multiply_3[v_i0, v_i1])
compute[v_i0, v_i1] = T.tanh(T_multiply_5[v_i0, v_i1])
for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
with T.block("T_add_1"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(compute[v_ax0, v_ax1])
T.writes(T_add_1[v_ax0, v_ax1])
T_add_1[v_ax0, v_ax1] = T.float32(1) + compute[v_ax0, v_ax1]
for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
with T.block("T_multiply_3"):
with T.block("T_multiply_5"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(T_multiply_1[v_ax0, v_ax1], T_add_1[v_ax0, v_ax1])
T.writes(T_multiply[v_ax0, v_ax1])
Expand Down Expand Up @@ -1344,11 +1351,13 @@ def gelu_tanh(var_A: T.handle, var_T_multiply: T.handle):
m, n = T.int64(), T.int64()
A = T.match_buffer(var_A, (m, n))
T_multiply = T.match_buffer(var_T_multiply, (m, n))
# with T.block("root"):
T_multiply_1 = T.alloc_buffer((m, n))
T_power = T.alloc_buffer((m, n))
T_multiply_2 = T.alloc_buffer((m, n))
T_add = T.alloc_buffer((m, n))
T_multiply_3 = T.alloc_buffer((m, n))
T_multiply_4 = T.alloc_buffer((m, n))
T_add = T.alloc_buffer((m, n))
T_multiply_5 = T.alloc_buffer((m, n))
compute = T.alloc_buffer((m, n))
T_add_1 = T.alloc_buffer((m, n))
for ax0, ax1 in T.grid(m, n):
Expand All @@ -1358,43 +1367,49 @@ def gelu_tanh(var_A: T.handle, var_T_multiply: T.handle):
T.writes(T_multiply_1[v_ax0, v_ax1])
T_multiply_1[v_ax0, v_ax1] = T.float32(0.5) * A[v_ax0, v_ax1]
for ax0, ax1 in T.grid(m, n):
with T.block("T_power"):
with T.block("T_multiply_1"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(A[v_ax0, v_ax1])
T.writes(T_power[v_ax0, v_ax1])
T_power[v_ax0, v_ax1] = T.pow(A[v_ax0, v_ax1], T.float32(3))
T.writes(T_multiply_2[v_ax0, v_ax1])
T_multiply_2[v_ax0, v_ax1] = T.float32(0.79788456080286541) * A[v_ax0, v_ax1]
for ax0, ax1 in T.grid(m, n):
with T.block("T_multiply_1"):
with T.block("T_multiply_2"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(T_power[v_ax0, v_ax1])
T.writes(T_multiply_2[v_ax0, v_ax1])
T_multiply_2[v_ax0, v_ax1] = T.float32(0.044714999999999998) * T_power[v_ax0, v_ax1]
T.reads(A[v_ax0, v_ax1])
T.writes(T_multiply_3[v_ax0, v_ax1])
T_multiply_3[v_ax0, v_ax1] = T.float32(0.044714999999999998) * A[v_ax0, v_ax1]
for ax0, ax1 in T.grid(m, n):
with T.block("T_multiply_3"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(T_multiply_3[v_ax0, v_ax1], A[v_ax0, v_ax1])
T.writes(T_multiply_4[v_ax0, v_ax1])
T_multiply_4[v_ax0, v_ax1] = T_multiply_3[v_ax0, v_ax1] * A[v_ax0, v_ax1]
for ax0, ax1 in T.grid(m, n):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(A[v_ax0, v_ax1], T_multiply_2[v_ax0, v_ax1])
T.reads(T_multiply_4[v_ax0, v_ax1])
T.writes(T_add[v_ax0, v_ax1])
T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + T_multiply_2[v_ax0, v_ax1]
T_add[v_ax0, v_ax1] = T.float32(1) + T_multiply_4[v_ax0, v_ax1]
for ax0, ax1 in T.grid(m, n):
with T.block("T_multiply_2"):
with T.block("T_multiply_4"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(T_add[v_ax0, v_ax1])
T.writes(T_multiply_3[v_ax0, v_ax1])
T_multiply_3[v_ax0, v_ax1] = T.float32(0.79788456080286541) * T_add[v_ax0, v_ax1]
T.reads(T_multiply_2[v_ax0, v_ax1], T_add[v_ax0, v_ax1])
T.writes(T_multiply_5[v_ax0, v_ax1])
T_multiply_5[v_ax0, v_ax1] = T_multiply_2[v_ax0, v_ax1] * T_add[v_ax0, v_ax1]
for i0, i1 in T.grid(m, n):
with T.block("compute"):
v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
T.reads(T_multiply_3[v_i0, v_i1])
T.reads(T_multiply_5[v_i0, v_i1])
T.writes(compute[v_i0, v_i1])
compute[v_i0, v_i1] = T.tanh(T_multiply_3[v_i0, v_i1])
compute[v_i0, v_i1] = T.tanh(T_multiply_5[v_i0, v_i1])
for ax0, ax1 in T.grid(m, n):
with T.block("T_add_1"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(compute[v_ax0, v_ax1])
T.writes(T_add_1[v_ax0, v_ax1])
T_add_1[v_ax0, v_ax1] = T.float32(1) + compute[v_ax0, v_ax1]
for ax0, ax1 in T.grid(m, n):
with T.block("T_multiply_3"):
with T.block("T_multiply_5"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(T_multiply_1[v_ax0, v_ax1], T_add_1[v_ax0, v_ax1])
T.writes(T_multiply[v_ax0, v_ax1])
Expand Down
Loading