Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Finer predicate handling in cross-thread reduction #15374

Merged

Conversation

MasterJH5574
Copy link
Contributor

This PR fixes the predicate handling logic of the cross-thread reduction lowering pass.

For the cross-thread reduction write-back block, prior to this PR, its predicate is the conjunction of t == 0 for each reduction thread dim of the cross-thread reduction. This is problematic when the write-back buffer is stored in local memory, where each thread is supposed to have a copy of the final value, while the final value is only stored by the first thread. In this PR, the predicate is changed to be the conjunction of the clauses from the two parts:

  • the clause of the original reduction block's predicate which contains spatial loop var,
  • t == 0 for each reduction thread dim only when the write-back buffer is global or shared.

So the first part ensures that the write-back will not go out of bound, and the second part ensures that when the write-back buffer is local, every thread gets a value and when the write-back buffer is non-local, only one thread writes the value out.

Meanwhile, this PR fixes the cross-thread broadcasting detection with the awareness of the storage scope of the write buffer of the broadcasting block. Specifically, for each consumer block of a buffer produced by cross-thread reduction under the same kernel (i.e., same set of blockIdx) of the cross-thread reduction block, when the write buffer of this consumer block is in local memory, we do not treat it as broadcasting, and will not add a predicate to it. Otherwise, we will add the predicate according to the broadcasting handling introduced by #15192.

@tvm-bot
Copy link
Collaborator

tvm-bot commented Jul 21, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

@MasterJH5574
Copy link
Contributor Author

Depending on #15373 to make sure that each thread can access to the reduction results after allreduce.

@MasterJH5574
Copy link
Contributor Author

cc @yzh119 @tqchen

@MasterJH5574 MasterJH5574 force-pushed the tvm-dev/2023-07-20-crt-write-back-pred branch 3 times, most recently from bfa3c00 to b6e3a83 Compare July 21, 2023 16:25
This PR fixes the predicate handling logic of the cross-thread
reduction lowering pass.

For the cross-thread reduction write-back block, prior to this PR, its
predicate is the conjunction of `t == 0` for each reduction thread dim
of the cross-thread reduction. This is problematic when the write-back
buffer is stored in local memory, where each thread is supposed to
have a copy of the final value, while the final value is only stored
by the first thread. In this PR, the predicate is changed to be the
conjunction of the clauses from the two parts:

* the clause of the original reduction block's predicate which contains
spatial loop var,
* `t == 0` for each reduction thread dim **only when the write-back
buffer is global or shared**.

So the first part ensures that the write-back will not go out of bound,
and the second part ensures that when the write-back buffer is local,
every thread gets a value and when the write-back buffer is non-local,
only one thread writes the value out.

Meanwhile, this PR fixes the cross-thread broadcasting detection with
the awareness of the storage scope of the write buffer of the
broadcasting block. Specifically, for each consumer block of a buffer
produced by cross-thread reduction under the same kernel (i.e., same
set of `blockIdx`) of the cross-thread reduction block, when the
write buffer of this consumer block is in local memory, we do not treat
it as broadcasting, and will not add a predicate to it. Otherwise,
we will add the predicate according to the broadcasting handling
introduced by apache#15192.
@MasterJH5574 MasterJH5574 force-pushed the tvm-dev/2023-07-20-crt-write-back-pred branch from b6e3a83 to 8a2dd01 Compare July 22, 2023 00:01
@tqchen tqchen merged commit 3f69ed4 into apache:main Jul 22, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants