[TIR] Enhance Lower cross thread Pass #17133

Open
wants to merge 7 commits into
base: main

LeiWang1999
Contributor

We currently support lowering cross-thread reduction only under several constraints. For example, lower_cross_thread applies only when the thread-bound reduction axis is the innermost loop, and the block must have an init block. This can be limiting in some cases.

For example, when tensorizing the reduction block (e.g., dp4a or mma), it becomes difficult to tensorize the init statement as well:

with T.block("block"):
    vi = T.axis.spatial(2, i_0 * 16 + i_1)
    vk = T.axis.reduce(32, k_0 * 64 + k_1)
    T.where(i_0 * 16 + i_1 < 2 and k_0 * 64 + k_1 < 32)
    T.reads(A[vi, vk])
    T.writes(B[vi])
    with T.init():
        B[vi] = T.float32(0)
    B[vi] = B[vi] + A[vi, vk]

Moreover, certain cases, such as small GEMMs, prefer block reduction in shared memory, which increases parallelism and makes better use of hardware resources.

This pull request improves the lower_cross_thread pass: it can now lower thread-block reductions that have separate init and reduce blocks, and it removes the constraint that the reduced axis must be the innermost loop, which supports TensorCore with block reduction.

Relevant test cases can be found at tests/python/tir-transform/test_tir_transform_lower_cross_thread_reduction.py.

Please CC @MasterJH5574 .

@LeiWang1999 LeiWang1999 changed the title [TIR] Improve Lower cross thread Pass to handle block reduction [TIR] Improve Lower cross thread Pass Jul 3, 2024
@LeiWang1999 LeiWang1999 changed the title [TIR] Improve Lower cross thread Pass [TIR] Enhance Lower cross thread Pass Jul 3, 2024
@tqchen
Member

tqchen commented Jul 31, 2024

@LeiWang1999 please fix the lint and test case. @wrongtest-intellif, do you mind helping review the PR?

// output buffers.
if (!IsDominantBlock(scope_block, GetRef<Block>(block))) {
  return false;
}
Contributor

Hi @LeiWang1999, could you explain why the dominant check here was removed?

Contributor Author

@wrongtest-intellif Thanks for your review. I'll take a look and give you a response this Friday.

src/tir/transforms/lower_cross_thread_reduction.cc (outdated)
Evaluate(Call(/*dtype=*/DataType::Handle(),
/*op=*/tir::builtin::tvm_thread_allreduce(),
/*args=*/std::move(parameters)))));
ObjectPtr<BlockNode> cross_thread_block_node =
Contributor

Could we initialize kIsCrossThreadReductionApplied just in the constructor?

Contributor Author

Yeah, absolutely, but the problem is that I don't know how to skip the initialization of init, alloc_buffers and match_buffers; giving them empty values might be a bit ugly.

Block::Block(Array<IterVar> iter_vars, Array<BufferRegion> reads, Array<BufferRegion> writes,
             String name_hint, Stmt body, Optional<Stmt> init, Array<Buffer> alloc_buffers,
             Array<MatchBufferRegion> match_buffers, Map<String, ObjectRef> annotations,
             Span span)

Contributor

We could also use block.CopyOnWrite()->annotations.Set(...) if the full constructor parameters look ugly.
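
To illustrate the copy-on-write suggestion above without pulling in TVM headers, here is a minimal self-contained sketch. BlockLike, BlockNodeLike, MarkCrossThreadReductionApplied and the annotation key string are all hypothetical stand-ins (the thread does not show the string behind kIsCrossThreadReductionApplied), and std::shared_ptr only approximates TVM's reference counting. The point is that the handle clones the node only when it is shared, so one field can be updated without re-listing every constructor argument.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <utility>

// Hypothetical stand-in for BlockNode: only two fields, for illustration.
struct BlockNodeLike {
  std::string name_hint;
  std::map<std::string, bool> annotations;
};

// Hypothetical stand-in for the reference-counted Block handle.
class BlockLike {
 public:
  explicit BlockLike(std::string name)
      : node_(std::make_shared<BlockNodeLike>(BlockNodeLike{std::move(name), {}})) {}

  // Read-only access to the underlying node.
  const BlockNodeLike* operator->() const { return node_.get(); }

  // If other handles share the node, clone it first; then return a mutable pointer.
  BlockNodeLike* CopyOnWrite() {
    if (node_.use_count() > 1) {
      node_ = std::make_shared<BlockNodeLike>(*node_);
    }
    return node_.get();
  }

 private:
  std::shared_ptr<BlockNodeLike> node_;
};

// Sets a single annotation; every other field is carried over untouched.
// The key string is a placeholder, not the real TVM attribute value.
inline BlockLike MarkCrossThreadReductionApplied(BlockLike block) {
  block.CopyOnWrite()->annotations["cross_thread_reduction_applied"] = true;
  return block;
}
```

Because the argument is passed by value, the caller's handle still shares the node when CopyOnWrite() runs, so the original block is left unmodified and only the returned copy carries the annotation.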

src/tir/transforms/lower_cross_thread_reduction.cc (outdated)
if (buf_it == crt_buf2threads_.end()) {
continue;
}
for (auto[scope, range] : buf_it->second) {
Contributor

requires clang-format

src/tir/transforms/lower_cross_thread_reduction.cc (outdated)
* \param block The block to be checked
* \return The init value of the given block
*/
static PrimExpr FindInit(const Block& block) {
Contributor

Judging by the usage in context, maybe we could merge FindInit() and CheckHasMMA() so that the block is visited only once.


 private:
  void VisitStmt_(const BufferStoreNode* node) final {
    BufferStore store = GetRef<BufferStore>(node);
Contributor

Do we need to check that there is only a single buffer store and its related value?

* \brief kind The kind of the for loop
* \return The loop variables between stmt1 and stmt2
*/
class LoopVar {
Contributor

Could we just use a ForNode object?

@LeiWang1999
Contributor Author

I'm attempting to replace LoopVar with ForNode, and I've encountered unexpected behavior.

For new_for = For(ax_lane_id, Integer(0), warp_size, ForKind::kThreadBinding, n->body);
const ForNode* new_for_node = new_for.get();
LOG(INFO) << "new_for->min " << new_for_node->min;
// Output: 0

This snippet works fine and logs the expected output of 0 for new_for->min.

However, when I try to instantiate new_for_node directly, like this:

const ForNode* new_for_node = For(ax_lane_id, Integer(0), warp_size, ForKind::kThreadBinding, n->body).get();
LOG(INFO) << "new_for->min " << new_for_node->min;
// Check failed: (tindex < type_table_.size() && type_table_[tindex].allocated_slots != 0) is false: Unknown type index 76391888

@wrongtest-intellif, do you have any thoughts?

@wrongtest-intellif
Contributor

wrongtest-intellif commented Sep 24, 2024

> However, when I try to instantiate new_for_node directly, like this:

Node pointers do not take ownership of objects. It seems the object referenced by new_for_node is already destructed at the end of the previous line, leaving a dangling pointer.

@LeiWang1999
Contributor Author

> > However, when I try to instantiate new_for_node directly, like this:
>
> Node pointers do not take ownership of objects. It seems the object referenced by new_for_node is already destructed at the end of the previous line, leaving a dangling pointer.

Thanks, exactly: it was an ownership issue. I now use For instead of ForNode* to take advantage of TVM's ObjectRef management.
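
The lifetime issue discussed above can be reproduced outside TVM with a small stand-in. In this sketch, ForLike and ForNodeLike are hypothetical types (not TVM's For and ForNode), std::shared_ptr approximates the reference-counted handle, and watch() is a helper that exists only so the destruction can be observed safely. Dereferencing the dangling pointer itself would be undefined behavior, so the sketch checks a std::weak_ptr instead.

```cpp
#include <cassert>
#include <memory>

// Hypothetical stand-in for ForNode: the object that holds the data.
struct ForNodeLike {
  int min;
};

// Hypothetical stand-in for the reference-counted For handle.
class ForLike {
 public:
  explicit ForLike(int min) : node_(std::make_shared<ForNodeLike>(ForNodeLike{min})) {}

  // Non-owning raw pointer: it does not extend the node's lifetime.
  const ForNodeLike* get() const { return node_.get(); }

  // Illustration-only helper: observes the node without owning it.
  std::weak_ptr<ForNodeLike> watch() const { return node_; }

 private:
  std::shared_ptr<ForNodeLike> node_;
};

// Safe pattern: the named handle outlives every use of the raw pointer.
inline int ReadMinThroughNamedHandle() {
  ForLike new_for(0);
  const ForNodeLike* new_for_node = new_for.get();
  return new_for_node->min;
}

// Problematic pattern, observed safely: the temporary handle is destroyed at the
// end of the full expression, and it was the node's only owner.
inline bool NodeOutlivesTemporaryHandle() {
  std::weak_ptr<ForNodeLike> watcher = ForLike(0).watch();
  return !watcher.expired();
}
```

ReadMinThroughNamedHandle() corresponds to the first snippet in the thread, where new_for is a named variable. NodeOutlivesTemporaryHandle() corresponds to the second, where .get() is called on a temporary and the node is gone by the next statement.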

3 participants