Skip to content

Commit

Permalink
[Meta Schedule] Fix some bugs (#537)
Browse files Browse the repository at this point in the history
* fix

* fix

* fix

* fix
  • Loading branch information
spectrometerHBH authored Dec 7, 2021
1 parent d91b43f commit ce1e4b1
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/mutator/mutator.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def f_initialize_with_tune_context(tune_context: "TuneContext") -> None:
self.initialize_with_tune_context(tune_context)

@check_override(self.__class__, Mutator)
def f_apply(trace: Trace) -> Optional[Trace]:
def f_apply(trace: Trace, _) -> Optional[Trace]:
return self.apply(trace)

def f_as_string() -> str:
Expand Down
12 changes: 6 additions & 6 deletions src/tir/transforms/lower_cross_thread_reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,22 +149,22 @@ Array<BufferRegion> RemoveBufferFromBufferRegions(const Array<BufferRegion>& buf
/*!
* \brief Substitute a given source buffer with a given target buffer in statements or expressions
*/
class BufferReplacer : private StmtExprMutator {
class BufferMutator : private StmtExprMutator {
public:
static Stmt Run(Buffer src_buffer, Buffer tgt_buffer, Stmt stmt) {
return BufferReplacer(src_buffer, tgt_buffer)(std::move(stmt));
return BufferMutator(src_buffer, tgt_buffer)(std::move(stmt));
}

private:
explicit BufferReplacer(Buffer src_buffer, Buffer tgt_buffer)
explicit BufferMutator(Buffer src_buffer, Buffer tgt_buffer)
: src_buffer_(std::move(src_buffer)), tgt_buffer_(std::move(tgt_buffer)) {}

PrimExpr VisitExpr_(const BufferLoadNode* load) final {
PrimExpr VisitExpr_(const BufferLoadNode* load) override {
return load->buffer.same_as(src_buffer_) ? BufferLoad(tgt_buffer_, {0})
: GetRef<BufferLoad>(load);
}

Stmt VisitStmt_(const BufferStoreNode* store) final {
Stmt VisitStmt_(const BufferStoreNode* store) override {
if (store->buffer.same_as(src_buffer_)) {
PrimExpr value = StmtExprMutator::VisitExpr(store->value);
return BufferStore(tgt_buffer_, value, {0});
Expand Down Expand Up @@ -287,7 +287,7 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, const Optional<Buf
new_block->writes = {it_buffer_region.value()};
new_block->name_hint = new_block->name_hint + "_in_thread";
new_block->body =
BufferReplacer::Run(wb_buffer, it_buffer.value(), std::move(new_block->body));
BufferMutator::Run(wb_buffer, it_buffer.value(), std::move(new_block->body));
new_block->init = NullOpt;
ObjectPtr<BlockRealizeNode> n = make_object<BlockRealizeNode>(*realize);
n->block = Block(new_block);
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_meta_schedule_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def test_meta_schedule_error_handle_time_out():

def initializer():
@register_func("meta_schedule.builder.test_time_out")
def timeout_build(mod, target): # pylint: disable=unused-argument, unused-variable
def timeout_build(mod, target, _): # pylint: disable=unused-argument, unused-variable
time.sleep(2)

builder = LocalBuilder(
Expand Down

0 comments on commit ce1e4b1

Please sign in to comment.