diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 572bc5c9a131..a54da695df24 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -115,8 +115,7 @@ class WarpStoreCoeffFinder : private StmtExprVisitor { /// Visitor implementation void VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::ptx_ldmatrix()) && op->args[3].as() == buffer_) { - int num_matrix = op->args[1].as()->value; - warp_coeff_ = num_matrix * 2; + UpdatePattern(op->args[4]); } else if (op->op.same_as(builtin::mma_fill()) && op->args[1].as() == buffer_) { auto* ptr = op->args[0].as(); CHECK(ptr); @@ -499,7 +498,7 @@ Pass LowerWarpMemory() { WarpMemoryRewriter warp_memory_rewriter(warp_size); auto stmt = warp_memory_rewriter.Rewrite(std::move(n->body)); n->body = UpdatePointerStorageScope(warp_memory_rewriter.new_storage_scopes_)(stmt); - // LOG(INFO) << f; + LOG(INFO) << f; return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {});