diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 0e71d4743b..b3d7b0865c 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -21,7 +21,6 @@ from typing import Union, Optional, List, Mapping import warnings -from tvm.ir import transform import tvm.tir @@ -186,6 +185,7 @@ def _build_for_device(input_mod, target, target_host): and f.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH ), tvm.tir.transform.LowerWarpMemory(), + tvm.tir.transform.Simplify(), tvm.tir.transform.LowerDeviceStorageAccessInfo(), tvm.tir.transform.LowerCustomDatatypes(), tvm.tir.transform.LowerIntrin(), diff --git a/src/tir/transforms/lower_logical_intrin.cc b/src/tir/transforms/lower_logical_intrin.cc new file mode 100644 index 0000000000..1c885bfddc --- /dev/null +++ b/src/tir/transforms/lower_logical_intrin.cc @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \brief Lower logical intrinsics
+ * \file lower_logical_intrin.cc
+ */
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace tvm {
+namespace tir {
+
+struct LogicalIntrinRegistry {
+  static Map registry;
+};
+
+/*! \brief Rewrites accesses to an intrinsic's own buffers so they use the mapped buffers. */
+class LogicalIntrinBufferReplacer : public StmtExprMutator {
+ public:
+  explicit LogicalIntrinBufferReplacer(Map buffer_var_to_new_buffer)
+      : buffer_var_to_new_buffer_(std::move(buffer_var_to_new_buffer)) {}
+
+  // A mapped data variable is replaced by the data variable of its new buffer.
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = buffer_var_to_new_buffer_.find(GetRef(op));
+    if (it != buffer_var_to_new_buffer_.end()) {
+      return (*it).second->data;
+    }
+    return GetRef(op);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(op));
+    auto it = buffer_var_to_new_buffer_.find(load->buffer->data);
+    if (it != buffer_var_to_new_buffer_.end()) {
+      auto* n = load.CopyOnWrite();
+      n->buffer = (*it).second;
+    }
+    return load;
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op));
+    auto it = buffer_var_to_new_buffer_.find(store->buffer->data);
+    if (it != buffer_var_to_new_buffer_.end()) {
+      auto* n = store.CopyOnWrite();
+      n->buffer = (*it).second;
+    }
+    return store;
+  }
+
+ private:
+  /*! \brief Maps a buffer data variable to the buffer that replaces it. */
+  Map buffer_var_to_new_buffer_;
+};
+
+/*! \brief Replaces evaluated calls to ops that carry an FLowerLogicalIntrin attribute. */
+class LogicalIntrinMutator : public StmtMutator {
+ public:
+  using FLowerLogicalIntrin = runtime::TypedPackedFunc;
+
+  // Record the buffers bound to the function parameters, keyed by data variable.
+  explicit LogicalIntrinMutator(const PrimFunc& func) {
+    for (const auto& kv : func->buffer_map) {
+      const Buffer& buffer = kv.second;
+      buffer_data_to_buffer_.Set(buffer->data, buffer);
+    }
+  }
+
+  // Record the buffers allocated by the block before visiting it.
+  Stmt VisitStmt_(const BlockNode* op) final {
+    for (const auto& buffer : op->alloc_buffers) {
+      buffer_data_to_buffer_.Set(buffer->data, buffer);
+    }
+    return StmtMutator::VisitStmt_(op);
+  }
+
+  // Inline the registered implementation in place of the call, when there is one.
+  Stmt VisitStmt_(const EvaluateNode* op) final {
+    static const auto& f_lower_logical_intrin = Op::GetAttrMap("FLowerLogicalIntrin");
+    if (const auto* call = op->value.as()) {
+      if (const auto* call_op = call->op.as()) {
+        PrimFunc intrin_impl = f_lower_logical_intrin.get(GetRef(call_op), NullValue());
+        if (intrin_impl.defined()) {
+          CHECK_EQ(intrin_impl->params.size(), call->args.size());
+          Map subst_map;
+          for (size_t i = 0; i < call->args.size(); i++) {
+            // TODO type check
+            subst_map.Set(intrin_impl->params[i], call->args[i]);
+          }
+          Map new_buffer_map;
+          for (size_t i = 0; i < call->args.size(); i++) {
+            const auto& param = intrin_impl->params[i];
+            if (const auto* var = param.as()) {
+              if (var->dtype.is_handle()) {
+                Var buffer_var = Downcast(param);
+                auto it = intrin_impl->buffer_map.find(buffer_var);
+                CHECK(it != intrin_impl->buffer_map.end()) << buffer_var;
+                // TODO check buffer match
+                new_buffer_map.Set((*it).second->data,
+                                   buffer_data_to_buffer_.at(Downcast(call->args[i])));
+              }
+            }
+          }
+
+          auto body = Substitute(intrin_impl->body, subst_map);
+          return LogicalIntrinBufferReplacer(new_buffer_map)(body);
+        }
+      }
+    }
+    return StmtMutator::VisitStmt_(op);
+  }
+
+  /*! \brief Maps a buffer data variable to the buffer it belongs to. */
+  Map buffer_data_to_buffer_;
+};
+
+namespace transform {
+
+/*! \brief Create the PrimFunc pass that applies LogicalIntrinMutator to the function body. */
+Pass LowerLogicalIntrin() {
+  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+    n->body = LogicalIntrinMutator(f)(std::move(f->body));
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.LowerLogicalIntrin", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.LowerLogicalIntrin").set_body_typed(LowerLogicalIntrin);
+
+}  // namespace transform
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc
index b20c3d9249..cdf6030d20 100644
--- a/src/tir/transforms/lower_warp_memory.cc
+++ 
b/src/tir/transforms/lower_warp_memory.cc @@ -209,8 +209,8 @@ class WarpIndexFinder : private StmtVisitor { // Mutator to change the read pattern class WarpAccessRewriter : protected StmtExprMutator { public: - explicit WarpAccessRewriter(int warp_size, Var warp_index, int width, arith::Analyzer* analyzer) - : warp_size_(warp_size), warp_index_(std::move(warp_index)), width_(width), analyzer_(analyzer) {} + explicit WarpAccessRewriter(int warp_size, arith::Analyzer* analyzer) + : warp_size_(warp_size), analyzer_(analyzer) {} // Rewrite the allocate statement which transforms // warp memory to local memory. Stmt Rewrite(const AllocateNode* op) { @@ -218,6 +218,7 @@ class WarpAccessRewriter : protected StmtExprMutator { int alloc_size = op->constant_allocation_size(); ICHECK_GT(alloc_size, 0) << "warp memory only support constant alloc size"; alloc_size *= op->dtype.lanes(); + std::tie(warp_index_, width_) = WarpIndexFinder(warp_size_).Find(op->body); warp_coeff_ = WarpStoreCoeffFinder(buffer_, warp_index_, analyzer_).Find(op->body); // Align the local memory size. The number of elements may not @@ -357,7 +358,6 @@ class WarpMemoryRewriter : private StmtMutator { if (warp_size_ == 1) return stmt; BindVarBoundInfo binder(&analyzer_); binder(stmt); - std::tie(warp_index_, warp_access_width_) = WarpIndexFinder(warp_size_).Find(stmt); stmt = operator()(std::move(stmt)); return stmt; } @@ -369,7 +369,7 @@ class WarpMemoryRewriter : private StmtMutator { auto ret = StmtMutator::VisitStmt_(op); op = ret.as(); if (warp_buffer_.count(op->buffer_var.get())) { - WarpAccessRewriter rewriter(warp_size_, warp_index_, warp_access_width_, &analyzer_); + WarpAccessRewriter rewriter(warp_size_, &analyzer_); ret = rewriter.Rewrite(op); } return ret; @@ -389,8 +389,6 @@ class WarpMemoryRewriter : private StmtMutator { } int warp_size_{0}; - int warp_access_width_{0}; - Var warp_index_{NullValue()}; std::unordered_set warp_buffer_; arith::Analyzer analyzer_; // variable domain