From ce1e4b1afd0f074a29de2aa8d79a90fd4fe40b61 Mon Sep 17 00:00:00 2001 From: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Date: Mon, 6 Dec 2021 23:28:52 -0500 Subject: [PATCH] [Meta Schedule] Fix some bugs (#537) * fix * fix * fix * fix --- python/tvm/meta_schedule/mutator/mutator.py | 2 +- src/tir/transforms/lower_cross_thread_reduction.cc | 12 ++++++------ tests/python/unittest/test_meta_schedule_builder.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/python/tvm/meta_schedule/mutator/mutator.py b/python/tvm/meta_schedule/mutator/mutator.py index ec6f6c0568..d3b0085911 100644 --- a/python/tvm/meta_schedule/mutator/mutator.py +++ b/python/tvm/meta_schedule/mutator/mutator.py @@ -71,7 +71,7 @@ def f_initialize_with_tune_context(tune_context: "TuneContext") -> None: self.initialize_with_tune_context(tune_context) @check_override(self.__class__, Mutator) - def f_apply(trace: Trace) -> Optional[Trace]: + def f_apply(trace: Trace, _) -> Optional[Trace]: return self.apply(trace) def f_as_string() -> str: diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc index 630c00f8c1..aa811b49d7 100644 --- a/src/tir/transforms/lower_cross_thread_reduction.cc +++ b/src/tir/transforms/lower_cross_thread_reduction.cc @@ -149,22 +149,22 @@ Array RemoveBufferFromBufferRegions(const Array& buf /*! * \brief Substitute a given source buffer with a given target buffer in statements or expressions */ -class BufferReplacer : private StmtExprMutator { +class BufferMutator : private StmtExprMutator { public: static Stmt Run(Buffer src_buffer, Buffer tgt_buffer, Stmt stmt) { - return BufferReplacer(src_buffer, tgt_buffer)(std::move(stmt)); + return BufferMutator(src_buffer, tgt_buffer)(std::move(stmt)); } private: - explicit BufferReplacer(Buffer src_buffer, Buffer tgt_buffer) + explicit BufferMutator(Buffer src_buffer, Buffer tgt_buffer) : src_buffer_(std::move(src_buffer)), tgt_buffer_(std::move(tgt_buffer)) {} - PrimExpr VisitExpr_(const BufferLoadNode* load) final { + PrimExpr VisitExpr_(const BufferLoadNode* load) override { return load->buffer.same_as(src_buffer_) ? BufferLoad(tgt_buffer_, {0}) : GetRef(load); } - Stmt VisitStmt_(const BufferStoreNode* store) final { + Stmt VisitStmt_(const BufferStoreNode* store) override { if (store->buffer.same_as(src_buffer_)) { PrimExpr value = StmtExprMutator::VisitExpr(store->value); return BufferStore(tgt_buffer_, value, {0}); @@ -287,7 +287,7 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, const Optionalwrites = {it_buffer_region.value()}; new_block->name_hint = new_block->name_hint + "_in_thread"; new_block->body = - BufferReplacer::Run(wb_buffer, it_buffer.value(), std::move(new_block->body)); + BufferMutator::Run(wb_buffer, it_buffer.value(), std::move(new_block->body)); new_block->init = NullOpt; ObjectPtr n = make_object(*realize); n->block = Block(new_block); diff --git a/tests/python/unittest/test_meta_schedule_builder.py b/tests/python/unittest/test_meta_schedule_builder.py index fb3fa135a9..03476ddefa 100644 --- a/tests/python/unittest/test_meta_schedule_builder.py +++ b/tests/python/unittest/test_meta_schedule_builder.py @@ -201,7 +201,7 @@ def test_meta_schedule_error_handle_time_out(): def initializer(): @register_func("meta_schedule.builder.test_time_out") - def timeout_build(mod, target): # pylint: disable=unused-argument, unused-variable + def timeout_build(mod, target, _): # pylint: disable=unused-argument, unused-variable time.sleep(2) builder = LocalBuilder(