[TIR] Fixed LowerThreadallreduce not remapping Store buffer var (#8931)
* Fixed LowerThreadallreduce not remapping Store buffer var

* reenable warp reduction schedule for softmax with fused ops

Co-authored-by: masa <masa@pop-os.localdomain>
masahi and masa authored Sep 6, 2021
1 parent 22dbc3a commit ab0f055
Showing 3 changed files with 89 additions and 17 deletions.
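For orientation, a minimal sketch of applying the pass in isolation, following the run_passes helper in the test added by this commit. The helper name lower_allreduce is hypothetical, and mod is assumed to be an IRModule whose PrimFunc contains a cross-thread allreduce; the pass reads the function's "target" attribute to decide on warp-level lowering.

import tvm

def lower_allreduce(mod, target_str="cuda"):
    # Hypothetical helper (sketch): attach a target attribute, then run
    # LowerThreadAllreduce, as the new test's run_passes() does.
    target = tvm.target.Target(target_str)
    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod)
    return tvm.tir.transform.LowerThreadAllreduce()(mod)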
5 changes: 1 addition & 4 deletions python/tvm/topi/cuda/softmax.py
@@ -75,10 +75,7 @@ def sched_warp_softmax():
         for op in ops:
             s = schedule_injective_from_existing(s, op.output(0))
 
-    elif sched_warp_softmax() and softmax_op == outs[0].op:
-        # TODO(masahi): Fix LowerThreadAllreduce pass to remove
-        # softmax_op == outs[0].op condition
-
+    elif sched_warp_softmax():
         # A warp of 32 threads performs a row reduction.
         num_thread = tgt.thread_warp_size
         block_x = te.thread_axis("blockIdx.x")
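With the softmax_op == outs[0].op guard removed, a softmax whose result feeds another elementwise op can again take the warp-reduction schedule. A minimal sketch of that fused case, mirroring the new test below; schedule_softmax is called under a CUDA target context, and no GPU device is needed just to build the schedule.

import tvm
from tvm import te, topi

# Softmax fused with an elementwise add: softmax is no longer the output op,
# which previously forced the non-warp fallback schedule.
A = te.placeholder((16, 16), dtype="float32", name="A")
B = topi.nn.softmax(A, axis=1) + 1.0

with tvm.target.Target("cuda"):
    s = topi.cuda.schedule_softmax(B)  # now takes the sched_warp_softmax() branch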
15 changes: 15 additions & 0 deletions src/tir/transforms/lower_thread_allreduce.cc
@@ -119,6 +119,17 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     }
   }
 
+  Stmt VisitStmt_(const StoreNode* op) final {
+    auto it = store_remap_.find(op->buffer_var.get());
+    if (it != store_remap_.end()) {
+      ICHECK(is_zero(op->index));
+      auto value = StmtExprMutator::VisitExpr(op->value);
+      return Store(it->second, value, 0, op->predicate);
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
   std::unordered_map<const VarNode*, String> new_storage_scopes_;
 
  private:
@@ -328,6 +339,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       PrimExpr pred = const_true(types[i].lanes());
       Var var = shared_bufs[i];
       load_remap_[buffers[i]] = Load(types[i], var, index, pred);
+      store_remap_[buffers[i]] = var;
       Array<PrimExpr> extents{PrimExpr(1)};
       auto node = Allocate(var, types[i], extents, pred, Evaluate(0));
       alloc_remap_[buffers[i]] = node;
@@ -370,6 +382,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       alloc_remap_[buffers[idx]] =
           Allocate(shared_bufs[idx], types[idx],
                    {PrimExpr(group_extent), PrimExpr(reduce_extent)}, pred, Evaluate(0));
+      store_remap_[buffers[idx]] = shared_bufs[idx];
     }
   }
@@ -587,6 +600,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   std::vector<const CommReducerNode*> reduce_combiner_;
   // The load remap
   std::unordered_map<const VarNode*, PrimExpr> load_remap_;
+  // The store remap
+  std::unordered_map<const VarNode*, Var> store_remap_;
   // Allocate remap
   std::unordered_map<const VarNode*, Stmt> alloc_remap_;
   // Allocate from warp reductions
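The new StoreNode visitor makes the pass rewrite the destination of a store the same way it already rewrote loads: a store whose buffer var has an entry in store_remap_ is redirected to the shared (or warp) buffer. A small sketch of how one might check this on a lowered module; it is essentially the assertion made by the test added below, and mod is assumed to come from a pipeline ending in LowerThreadAllreduce.

import tvm

def assert_stores_remapped(mod, buf_name="reduce_temp0"):
    # Sketch: fail if any Store still writes to the original reduction buffer.
    def visit(op):
        if isinstance(op, tvm.tir.Store):
            assert op.buffer_var.name != buf_name

    tvm.tir.stmt_functor.post_order_visit(mod["main"].body, visit)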
86 changes: 73 additions & 13 deletions tests/python/integration/test_reduce.py
@@ -14,10 +14,13 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import pytest
 
 import tvm
-from tvm import te
+from tvm import te, topi
 import numpy as np
 import tvm.testing
+import tvm.topi.testing
 
 
 @tvm.testing.requires_gpu
@@ -524,16 +527,73 @@ def check_target(device):
     check_target("rocm")
 
 
+@tvm.testing.requires_gpu
+def test_reduce_storage_reuse():
+    target = tvm.target.Target("cuda")
+
+    def run_passes(sch, args):
+        bounds = tvm.te.schedule.InferBound(sch)
+        stmt = tvm.te.schedule.ScheduleOps(sch, bounds)
+        func = tvm.te.schedule.SchedulePostProcToPrimFunc(args, stmt, None)
+        mod = tvm.IRModule.from_expr(func)
+        mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod)
+        return tvm.transform.Sequential(
+            [
+                tvm.tir.transform.StorageFlatten(64),
+                tvm.tir.transform.Simplify(),
+                tvm.tir.transform.StorageRewrite(),
+                tvm.tir.transform.LowerThreadAllreduce(),
+            ]
+        )(mod)
+
+    dev = tvm.device(target.kind.name, 0)
+    shape = (16, 16)
+
+    A = te.placeholder(shape, dtype="float32", name="A")
+    B = topi.nn.softmax(A, axis=1) + 1.0
+
+    with tvm.target.Target(target):
+        s = topi.cuda.schedule_softmax(B)
+
+    mod = run_passes(s, [A, B])
+
+    # Due to the storage rewrite pass, the reduction output storage reduce_temp0 can be reused as
+    # the storage of the next compute.
+
+    # Example:
+    # ...
+    # tir.tvm_thread_allreduce((uint32)1, normal_reduce_temp0[0], 1, reduce_temp0, threadIdx.x)
+    # if ((threadIdx.x < 16)) {
+    #     reduce_temp0[0] = (T_softmax_exp[threadIdx.x]/reduce_temp0[0])
+    # }
+    # ...
+
+    # The LowerThreadAllreduce pass should remap reduce_temp0 on the left hand side of the store
+    # above, as well as the load on the right hand side.
+
+    # Expected output:
+    # ...
+    # red_buf0[0] = tir.tvm_warp_shuffle(mask[0], red_buf0[0], 0, 32, 32)
+    # if ((threadIdx.x < 16)) {
+    #     red_buf0[0] = (T_softmax_exp[threadIdx.x]/red_buf0[0])
+    # }
+    # ...
+
+    def check_store_dst_remapped(op):
+        if isinstance(op, tvm.tir.Store):
+            assert op.buffer_var.name != "reduce_temp0"
+
+    tvm.tir.stmt_functor.post_order_visit(mod["main"].body, check_store_dst_remapped)
+
+    inp = np.random.uniform(size=shape).astype("float32")
+    ref = tvm.topi.testing.softmax_python(inp) + 1.0
+
+    f = tvm.build(s, [A, B], target)
+    a = tvm.nd.array(inp, dev)
+    b = tvm.nd.array(np.zeros(shape, dtype=B.dtype), dev)
+    f(a, b)
+    tvm.testing.assert_allclose(b.numpy(), ref, rtol=1e-5)
+
+
 if __name__ == "__main__":
-    test_rfactor_elemwise_threads()
-    test_rfactor_threads()
-    test_rfactor_factor_axis()
-    test_rfactor()
-    test_reduce_prims()
-    test_argmax()
-    test_rfactor_argmax()
-    test_warp_reduction1()
-    test_warp_reduction2()
-    test_init()
-    test_init_imm()
-    test_rfactor_init()
+    pytest.main([__file__])
