diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index a6ad06e544a6..0d6c3ef58cdf 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -175,7 +175,7 @@ def fast_softmax_strategy(attrs, inputs, out_type, target):
     strategy = _op.OpStrategy()
     strategy.add_implementation(
         wrap_compute_softmax(topi.nn.fast_softmax),
-        naive_schedule,
+        wrap_topi_schedule(topi.generic.schedule_fast_softmax),
         name="fast_softmax.generic",
     )
     return strategy
diff --git a/python/tvm/topi/generic/nn.py b/python/tvm/topi/generic/nn.py
index 866887706862..04d649037fef 100644
--- a/python/tvm/topi/generic/nn.py
+++ b/python/tvm/topi/generic/nn.py
@@ -563,6 +563,23 @@ def schedule_softmax(outs):
     return _default_schedule(outs, False)
 
 
+def schedule_fast_softmax(outs):
+    """Schedule for fast_softmax
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of fast_softmax
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
+
+
 def schedule_dense(outs):
     """Schedule for dense
 
diff --git a/src/relay/transforms/fast_math.cc b/src/relay/transforms/fast_math.cc
index 91fb4cfa8973..f6da52ebe30c 100644
--- a/src/relay/transforms/fast_math.cc
+++ b/src/relay/transforms/fast_math.cc
@@ -34,7 +34,11 @@ namespace relay {
 
 class FastMathMutator : public ExprRewriter {
  public:
-  FastMathMutator() : exp_op_(Op::Get("exp")), erf_op_(Op::Get("erf")), tanh_op_(Op::Get("tanh")) {}
+  FastMathMutator()
+      : exp_op_(Op::Get("exp")),
+        erf_op_(Op::Get("erf")),
+        tanh_op_(Op::Get("tanh")),
+        softmax_op_(Op::Get("nn.softmax")) {}
 
   Expr Rewrite_(const CallNode* pre, const Expr& post) override {
     if (pre->op == exp_op_) {
@@ -43,6 +47,8 @@ class FastMathMutator : public ExprRewriter {
       return FastErf(post.as<CallNode>()->args[0]);
     } else if (pre->op == tanh_op_) {
       return FastTanh(post.as<CallNode>()->args[0]);
+    } else if (pre->op == softmax_op_) {
+      return FastSoftmax(post.as<CallNode>()->args[0], post.as<CallNode>()->attrs);
     }
     return post;
   }
@@ -54,6 +60,7 @@ class FastMathMutator : public ExprRewriter {
   const Op& exp_op_;
   const Op& erf_op_;
   const Op& tanh_op_;
+  const Op& softmax_op_;
 };
 
 Expr FastMath(const Expr& e) {
diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h
index 50a695bf1d84..920ac153b63d 100644
--- a/src/relay/transforms/pattern_utils.h
+++ b/src/relay/transforms/pattern_utils.h
@@ -498,6 +498,11 @@ inline Expr FastTanh(Expr e) {
   return Call(op, {e});
 }
 
+inline Expr FastSoftmax(Expr e, tvm::Attrs attr) {
+  static const Op& op = Op::Get("nn.fast_softmax");
+  return Call(op, {e}, attr);
+}
+
 inline Expr Log(Expr e) {
   static const Op& op = Op::Get("log");
   return Call(op, {e});
diff --git a/tests/python/relay/test_op_fast_math.py b/tests/python/relay/test_op_fast_math.py
index 8e401bc5670a..c9314fae37ac 100644
--- a/tests/python/relay/test_op_fast_math.py
+++ b/tests/python/relay/test_op_fast_math.py
@@ -27,7 +27,7 @@ def test_fastmath():
     def test_apply(relay_op, name, f_numpy, low, high, step, dtype="float32"):
-        a_np = np.arange(low, high, step).astype(dtype)
+        a_np = np.arange(low, high, step).astype(dtype).reshape((1, -1))
         b_np = f_numpy(a_np)
 
         x = relay.var("x", shape=a_np.shape, dtype="float32")
@@ -56,6 +56,14 @@ def test_apply(relay_op, name, f_numpy, low, high, step, dtype="float32"):
     test_apply(relay.exp, "fast_exp", np.exp, low=-88, high=88, step=0.01)
     test_apply(relay.erf, "fast_erf", scipy.special.erf, low=-10, high=10, step=0.01)
     test_apply(relay.tanh, "fast_tanh", np.tanh, low=-10, high=10, step=0.01)
+    test_apply(
+        relay.nn.fast_softmax,
+        "nn_fast_softmax",
+        tvm.topi.testing.softmax_python,
+        low=-10,
+        high=10,
+        step=0.01,
+    )
 
 
 if __name__ == "__main__":
diff --git a/tests/python/relay/test_pass_fast_math.py b/tests/python/relay/test_pass_fast_math.py
index bb3fb84fc61f..f63b6ce0f23e 100644
--- a/tests/python/relay/test_pass_fast_math.py
+++ b/tests/python/relay/test_pass_fast_math.py
@@ -65,7 +65,19 @@ def test_erf():
     assert "fast_erf" in fast_mod[0].astext()
 
 
+def test_softmax():
+    x = relay.var("x", shape=(1, 16), dtype="float32")
+    y = relay.nn.softmax(x)
+    func = relay.Function([x], y)
+    mod = tvm.IRModule.from_expr(func)
+
+    with tvm.transform.PassContext(opt_level=3, required_pass=["FastMath"]):
+        fast_mod = relay.optimize(mod, target="llvm")
+    assert "nn.fast_softmax" in fast_mod[0].astext()
+
+
 if __name__ == "__main__":
     test_exp()
     test_tanh()
     test_erf()
+    test_softmax()