Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Meta Schedule] Fix some bugs #537

Merged
merged 4 commits into from
Dec 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/mutator/mutator.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def f_initialize_with_tune_context(tune_context: "TuneContext") -> None:
self.initialize_with_tune_context(tune_context)

@check_override(self.__class__, Mutator)
def f_apply(trace: Trace) -> Optional[Trace]:
def f_apply(trace: Trace, _) -> Optional[Trace]:
return self.apply(trace)

def f_as_string() -> str:
Expand Down
12 changes: 6 additions & 6 deletions src/tir/transforms/lower_cross_thread_reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,22 +149,22 @@ Array<BufferRegion> RemoveBufferFromBufferRegions(const Array<BufferRegion>& buf
/*!
* \brief Substitute a given source buffer with a given target buffer in statements or expressions
*/
class BufferReplacer : private StmtExprMutator {
class BufferMutator : private StmtExprMutator {
public:
static Stmt Run(Buffer src_buffer, Buffer tgt_buffer, Stmt stmt) {
return BufferReplacer(src_buffer, tgt_buffer)(std::move(stmt));
return BufferMutator(src_buffer, tgt_buffer)(std::move(stmt));
}

private:
explicit BufferReplacer(Buffer src_buffer, Buffer tgt_buffer)
explicit BufferMutator(Buffer src_buffer, Buffer tgt_buffer)
: src_buffer_(std::move(src_buffer)), tgt_buffer_(std::move(tgt_buffer)) {}

PrimExpr VisitExpr_(const BufferLoadNode* load) final {
PrimExpr VisitExpr_(const BufferLoadNode* load) override {
return load->buffer.same_as(src_buffer_) ? BufferLoad(tgt_buffer_, {0})
: GetRef<BufferLoad>(load);
}

Stmt VisitStmt_(const BufferStoreNode* store) final {
Stmt VisitStmt_(const BufferStoreNode* store) override {
if (store->buffer.same_as(src_buffer_)) {
PrimExpr value = StmtExprMutator::VisitExpr(store->value);
return BufferStore(tgt_buffer_, value, {0});
Expand Down Expand Up @@ -287,7 +287,7 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, const Optional<Buf
new_block->writes = {it_buffer_region.value()};
new_block->name_hint = new_block->name_hint + "_in_thread";
new_block->body =
BufferReplacer::Run(wb_buffer, it_buffer.value(), std::move(new_block->body));
BufferMutator::Run(wb_buffer, it_buffer.value(), std::move(new_block->body));
new_block->init = NullOpt;
ObjectPtr<BlockRealizeNode> n = make_object<BlockRealizeNode>(*realize);
n->block = Block(new_block);
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_meta_schedule_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def test_meta_schedule_error_handle_time_out():

def initializer():
@register_func("meta_schedule.builder.test_time_out")
def timeout_build(mod, target): # pylint: disable=unused-argument, unused-variable
def timeout_build(mod, target, _): # pylint: disable=unused-argument, unused-variable
time.sleep(2)

builder = LocalBuilder(
Expand Down