From a0c395f19700adf6023b7c4db2a2402dfd2381ff Mon Sep 17 00:00:00 2001
From: masahi
Date: Tue, 7 Sep 2021 06:18:39 +0900
Subject: [PATCH] [TIR] Fixed LowerThreadallreduce not remapping Store buffer var (#8931)

* Fixed LowerThreadallreduce not remapping Store buffer var

* reenable warp reduction schedule for softmax with fused ops

Co-authored-by: masa
---
 python/tvm/topi/cuda/softmax.py              |  5 +-
 src/tir/transforms/lower_thread_allreduce.cc | 15 ++++
 tests/python/integration/test_reduce.py      | 86 +++++++++++++++++---
 3 files changed, 89 insertions(+), 17 deletions(-)

diff --git a/python/tvm/topi/cuda/softmax.py b/python/tvm/topi/cuda/softmax.py
index 79f804f6f8a4..14d2963acf98 100644
--- a/python/tvm/topi/cuda/softmax.py
+++ b/python/tvm/topi/cuda/softmax.py
@@ -75,10 +75,7 @@ def sched_warp_softmax():
         for op in ops:
             s = schedule_injective_from_existing(s, op.output(0))
 
-    elif sched_warp_softmax() and softmax_op == outs[0].op:
-        # TODO(masahi): Fix LowerThreadAllreduce pass to remove
-        # softmax_op == outs[0].op condition
-
+    elif sched_warp_softmax():
         # A warp of 32 threads performs a row reduction.
         num_thread = tgt.thread_warp_size
         block_x = te.thread_axis("blockIdx.x")
diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc
index 481b1bfd4b19..6f7c09cdcf2d 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -119,6 +119,17 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     }
   }
 
+  Stmt VisitStmt_(const StoreNode* op) final {
+    auto it = store_remap_.find(op->buffer_var.get());
+    if (it != store_remap_.end()) {
+      ICHECK(is_zero(op->index));
+      auto value = StmtExprMutator::VisitExpr(op->value);
+      return Store(it->second, value, 0, op->predicate);
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
   std::unordered_map<const VarNode*, String> new_storage_scopes_;
 
  private:
@@ -328,6 +339,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       PrimExpr pred = const_true(types[i].lanes());
       Var var = shared_bufs[i];
       load_remap_[buffers[i]] = Load(types[i], var, index, pred);
+      store_remap_[buffers[i]] = var;
       Array<PrimExpr> extents{PrimExpr(1)};
       auto node = Allocate(var, types[i], extents, pred, Evaluate(0));
       alloc_remap_[buffers[i]] = node;
@@ -370,6 +382,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       alloc_remap_[buffers[idx]] =
           Allocate(shared_bufs[idx], types[idx],
                    {PrimExpr(group_extent), PrimExpr(reduce_extent)},
                    pred, Evaluate(0));
+      store_remap_[buffers[idx]] = shared_bufs[idx];
     }
   }
@@ -587,6 +600,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   std::vector<const CommReducerNode*> reduce_combiner_;
   // The load remap
   std::unordered_map<const VarNode*, PrimExpr> load_remap_;
+  // The store remap
+  std::unordered_map<const VarNode*, Var> store_remap_;
   // Allocate remap
   std::unordered_map<const VarNode*, Stmt> alloc_remap_;
   // Allocate from warp reductions
diff --git a/tests/python/integration/test_reduce.py b/tests/python/integration/test_reduce.py
index 939d0819546b..ca097734a9eb 100644
--- a/tests/python/integration/test_reduce.py
+++ b/tests/python/integration/test_reduce.py
@@ -14,10 +14,13 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import pytest
+
 import tvm
-from tvm import te
+from tvm import te, topi
 import numpy as np
 import tvm.testing
+import tvm.topi.testing
 
 
 @tvm.testing.requires_gpu
@@ -524,16 +527,73 @@ def check_target(device):
     check_target("rocm")
 
 
+@tvm.testing.requires_gpu
+def test_reduce_storage_reuse():
+    target = tvm.target.Target("cuda")
+
+    def run_passes(sch, args):
+        bounds = tvm.te.schedule.InferBound(sch)
+        stmt = tvm.te.schedule.ScheduleOps(sch, bounds)
+        func = tvm.te.schedule.SchedulePostProcToPrimFunc(args, stmt, None)
+        mod = tvm.IRModule.from_expr(func)
+        mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod)
+        return tvm.transform.Sequential(
+            [
+                tvm.tir.transform.StorageFlatten(64),
+                tvm.tir.transform.Simplify(),
+                tvm.tir.transform.StorageRewrite(),
+                tvm.tir.transform.LowerThreadAllreduce(),
+            ]
+        )(mod)
+
+    dev = tvm.device(target.kind.name, 0)
+    shape = (16, 16)
+
+    A = te.placeholder(shape, dtype="float32", name="A")
+    B = topi.nn.softmax(A, axis=1) + 1.0
+
+    with tvm.target.Target(target):
+        s = topi.cuda.schedule_softmax(B)
+
+    mod = run_passes(s, [A, B])
+
+    # Due to the storage rewrite pass, the reduction output storage reduce_temp0 can be reused as
+    # the storage of the next compute.
+
+    # Example:
+    # ...
+    # tir.tvm_thread_allreduce((uint32)1, normal_reduce_temp0[0], 1, reduce_temp0, threadIdx.x)
+    # if ((threadIdx.x < 16)) {
+    #     reduce_temp0[0] = (T_softmax_exp[threadIdx.x]/reduce_temp0[0])
+    # }
+    # ...
+
+    # The LowerThreadAllreduce pass should remap reduce_temp0 on the left hand side of the store
+    # above, as well as the load on the right hand side.
+
+    # Expected output:
+    # ...
+    # red_buf0[0] = tir.tvm_warp_shuffle(mask[0], red_buf0[0], 0, 32, 32)
+    # if ((threadIdx.x < 16)) {
+    #     red_buf0[0] = (T_softmax_exp[threadIdx.x]/red_buf0[0])
+    # }
+    # ...
+
+    def check_store_dst_remapped(op):
+        if isinstance(op, tvm.tir.Store):
+            assert op.buffer_var.name != "reduce_temp0"
+
+    tvm.tir.stmt_functor.post_order_visit(mod["main"].body, check_store_dst_remapped)
+
+    inp = np.random.uniform(size=shape).astype("float32")
+    ref = tvm.topi.testing.softmax_python(inp) + 1.0
+
+    f = tvm.build(s, [A, B], target)
+    a = tvm.nd.array(inp, dev)
+    b = tvm.nd.array(np.zeros(shape, dtype=B.dtype), dev)
+    f(a, b)
+    tvm.testing.assert_allclose(b.numpy(), ref, rtol=1e-5)
+
+
 if __name__ == "__main__":
-    test_rfactor_elemwise_threads()
-    test_rfactor_threads()
-    test_rfactor_factor_axis()
-    test_rfactor()
-    test_reduce_prims()
-    test_argmax()
-    test_rfactor_argmax()
-    test_warp_reduction1()
-    test_warp_reduction2()
-    test_init()
-    test_init_imm()
-    test_rfactor_init()
+    pytest.main([__file__])