[TIR] Finer predicate handling in cross-thread reduction #15374
Merged
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR fixes the predicate handling logic of the cross-thread reduction lowering pass.
For the cross-thread reduction write-back block, prior to this PR, its predicate is the conjunction of
t == 0
for each reduction thread dim of the cross-thread reduction. This is problematic when the write-back buffer is stored in local memory, where each thread is supposed to have a copy of the final value, while the final value is only stored by the first thread. In this PR, the predicate is changed to be the conjunction of the clauses from the two parts:t == 0
for each reduction thread dim only when the write-back buffer is global or shared.So the first part ensures that the write-back will not go out of bound, and the second part ensures that when the write-back buffer is local, every thread gets a value and when the write-back buffer is non-local, only one thread writes the value out.
Meanwhile, this PR fixes the cross-thread broadcasting detection with the awareness of the storage scope of the write buffer of the broadcasting block. Specifically, for each consumer block of a buffer produced by cross-thread reduction under the same kernel (i.e., same set of
blockIdx
) of the cross-thread reduction block, when the write buffer of this consumer block is in local memory, we do not treat it as broadcasting, and will not add a predicate to it. Otherwise, we will add the predicate according to the broadcasting handling introduced by #15192.