Skip to content

Commit

Permalink
lower intrin pass
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 committed Aug 31, 2021
1 parent 6f5b3ed commit 228f3f7
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 7 deletions.
2 changes: 1 addition & 1 deletion python/tvm/driver/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

from typing import Union, Optional, List, Mapping
import warnings
from tvm.ir import transform

import tvm.tir

Expand Down Expand Up @@ -186,6 +185,7 @@ def _build_for_device(input_mod, target, target_host):
and f.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH
),
tvm.tir.transform.LowerWarpMemory(),
tvm.tir.transform.Simplify(),
tvm.tir.transform.LowerDeviceStorageAccessInfo(),
tvm.tir.transform.LowerCustomDatatypes(),
tvm.tir.transform.LowerIntrin(),
Expand Down
157 changes: 157 additions & 0 deletions src/tir/transforms/lower_logical_intrin.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \brief Lower logical intrinsics
* \file lower_logical_intrin.cc
*/
#include <tvm/arith/iter_affine_map.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/target.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>


namespace tvm {
namespace tir {

/*!
 * \brief Holder for a static String -> PrimFunc map.
 *
 * NOTE(review): `registry` is only declared here. No out-of-class definition
 * and no use of it is visible in this file, so any odr-use would need a
 * definition to be added. Presumably intended as a name -> intrinsic
 * implementation table; confirm whether it is still needed.
 */
struct LogicalIntrinRegistry {
static Map<String, PrimFunc> registry;
};

/*!
 * \brief Mutator that redirects buffer accesses to replacement buffers.
 *
 * The replacement table is keyed by the data variable of the buffer being
 * replaced. Three kinds of nodes are rewritten:
 *  - a bare Var that is a key of the table becomes the data variable of the
 *    replacement buffer;
 *  - a BufferLoad / BufferStore whose buffer's data variable is a key of the
 *    table is re-pointed at the replacement buffer.
 * Everything else is handled by the default StmtExprMutator behaviour.
 */
class LogicalIntrinBufferReplacer : public StmtExprMutator {
 public:
  /*!
   * \brief Construct the replacer.
   * \param buffer_var_to_new_buffer Map from the data variable of a buffer to
   *        the buffer that should be used in its place.
   */
  explicit LogicalIntrinBufferReplacer(Map<Var, Buffer> buffer_var_to_new_buffer)
      : buffer_var_to_new_buffer_(std::move(buffer_var_to_new_buffer)) {}

  /*! \brief Replace a mapped data variable with the new buffer's data variable. */
  PrimExpr VisitExpr_(const VarNode* op) final {
    auto it = buffer_var_to_new_buffer_.find(GetRef<Var>(op));
    if (it != buffer_var_to_new_buffer_.end()) {
      return (*it).second->data;
    }
    return GetRef<Var>(op);
  }

  /*! \brief Re-point a load at the replacement buffer, if one is registered. */
  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
    // Run the default mutation first, then swap the buffer on the result.
    BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    auto it = buffer_var_to_new_buffer_.find(load->buffer->data);
    if (it != buffer_var_to_new_buffer_.end()) {
      auto* n = load.CopyOnWrite();
      n->buffer = (*it).second;
    }
    return load;
  }

  /*! \brief Re-point a store at the replacement buffer, if one is registered. */
  Stmt VisitStmt_(const BufferStoreNode* op) final {
    // Run the default mutation first, then swap the buffer on the result.
    BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    auto it = buffer_var_to_new_buffer_.find(store->buffer->data);
    if (it != buffer_var_to_new_buffer_.end()) {
      auto* n = store.CopyOnWrite();
      n->buffer = (*it).second;
    }
    return store;
  }

 private:
  /*! \brief Data variable of a buffer to replace -> replacement buffer. */
  Map<Var, Buffer> buffer_var_to_new_buffer_;
};

/*!
 * \brief Mutator that replaces calls to logical intrinsics with the body of
 *        their registered implementation.
 *
 * An `Evaluate(Call(op, args))` statement is lowered when `op` has an entry in
 * the "FLowerLogicalIntrin" op attribute map. The implementation is a PrimFunc
 * whose parameters correspond positionally to the call arguments. Parameters
 * are substituted by the arguments, and every buffer bound to a handle
 * parameter in the implementation's buffer_map is replaced by the buffer whose
 * data variable was passed at the call site.
 */
class LogicalIntrinMutator : public StmtMutator {
 public:
  using FLowerLogicalIntrin = runtime::TypedPackedFunc<PrimFunc(PrimExpr)>;

  /*!
   * \brief Construct the mutator and record the buffers of \p func.
   * \param func The function being lowered; its buffer_map seeds the lookup
   *        table from data variable to buffer.
   */
  explicit LogicalIntrinMutator(const PrimFunc& func) {
    for (const auto& kv : func->buffer_map) {
      const Buffer& buffer = kv.second;
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
  }

  /*! \brief Record buffers allocated by the block, then visit it as usual. */
  Stmt VisitStmt_(const BlockNode* op) final {
    // Registered before recursing so that calls inside the block can
    // resolve these buffers by their data variable.
    for (const auto& buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    return StmtMutator::VisitStmt_(op);
  }

  /*!
   * \brief Lower an evaluated call if its op has a registered implementation.
   * \return The substituted implementation body, or the default mutation
   *         result when the statement is not a lowerable intrinsic call.
   */
  Stmt VisitStmt_(const EvaluateNode* op) final {
    static const auto& f_lower_logical_intrin = Op::GetAttrMap<PrimFunc>("FLowerLogicalIntrin");

    const auto* call = op->value.as<CallNode>();
    const auto* call_op = (call == nullptr) ? nullptr : call->op.as<OpNode>();
    if (call_op == nullptr) {
      return StmtMutator::VisitStmt_(op);
    }
    PrimFunc intrin_impl = f_lower_logical_intrin.get(GetRef<Op>(call_op), NullValue<PrimFunc>());
    if (!intrin_impl.defined()) {
      return StmtMutator::VisitStmt_(op);
    }

    CHECK(intrin_impl->params.size() == call->args.size())
        << "Logical intrinsic implementation expects " << intrin_impl->params.size()
        << " arguments, but the call provides " << call->args.size();

    Map<Var, PrimExpr> subst_map;
    Map<Var, Buffer> new_buffer_map;
    for (size_t i = 0; i < call->args.size(); ++i) {
      Var param = intrin_impl->params[i];
      PrimExpr arg = call->args[i];
      // TODO type check
      subst_map.Set(param, arg);
      if (!param->dtype.is_handle()) {
        continue;
      }
      // A handle parameter must be bound to a buffer of the implementation.
      auto impl_it = intrin_impl->buffer_map.find(param);
      CHECK(impl_it != intrin_impl->buffer_map.end())
          << "Handle parameter " << param << " has no entry in the implementation's buffer_map";
      // The matching argument must be the data variable of a known buffer.
      const auto* arg_var = arg.as<VarNode>();
      CHECK(arg_var != nullptr) << "Expected a buffer data variable for parameter " << param
                                << ", but got " << arg;
      auto caller_it = buffer_data_to_buffer_.find(GetRef<Var>(arg_var));
      CHECK(caller_it != buffer_data_to_buffer_.end())
          << "No buffer is known for data variable " << arg;
      // TODO check buffer match
      new_buffer_map.Set((*impl_it).second->data, (*caller_it).second);
    }

    // The lowered body is returned as-is; it is not visited again by this mutator.
    Stmt body = Substitute(intrin_impl->body, subst_map);
    return LogicalIntrinBufferReplacer(new_buffer_map)(body);
  }

  /*! \brief Data variable -> buffer, for buffers seen so far. */
  Map<Var, Buffer> buffer_data_to_buffer_;
};

namespace transform {

/*!
 * \brief Create the pass that lowers logical intrinsic calls in a PrimFunc.
 *
 * The function body is rewritten by LogicalIntrinMutator; all other fields of
 * the PrimFunc are left untouched.
 *
 * \return A PrimFunc-level pass registered under the name "tir.LowerLogicalIntrin".
 */
Pass LowerLogicalIntrin() {
  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
    auto* n = f.CopyOnWrite();
    // Build the mutator first: it reads f's buffer_map, not the body.
    LogicalIntrinMutator mutator(f);
    // Move through the mutable node pointer; moving via the const accessor
    // `f->body` would silently degrade to a copy.
    n->body = mutator(std::move(n->body));
    return f;
  };
  // The pass name must match this pass so that PassContext options and
  // instrumentation that refer to passes by name address the right one.
  return CreatePrimFuncPass(pass_func, 0, "tir.LowerLogicalIntrin", {});
}

// Expose the pass constructor through the global function registry under the
// name "tir.transform.LowerLogicalIntrin".
TVM_REGISTER_GLOBAL("tir.transform.LowerLogicalIntrin").set_body_typed(LowerLogicalIntrin);

} // namespace transform

} // namespace tir
} // namespace tvm
10 changes: 4 additions & 6 deletions src/tir/transforms/lower_warp_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -209,15 +209,16 @@ class WarpIndexFinder : private StmtVisitor {
// Mutator to change the read pattern
class WarpAccessRewriter : protected StmtExprMutator {
public:
explicit WarpAccessRewriter(int warp_size, Var warp_index, int width, arith::Analyzer* analyzer)
: warp_size_(warp_size), warp_index_(std::move(warp_index)), width_(width), analyzer_(analyzer) {}
explicit WarpAccessRewriter(int warp_size, arith::Analyzer* analyzer)
: warp_size_(warp_size), analyzer_(analyzer) {}
// Rewrite the allocate statement which transforms
// warp memory to local memory.
Stmt Rewrite(const AllocateNode* op) {
buffer_ = op->buffer_var.get();
int alloc_size = op->constant_allocation_size();
ICHECK_GT(alloc_size, 0) << "warp memory only support constant alloc size";
alloc_size *= op->dtype.lanes();
std::tie(warp_index_, width_) = WarpIndexFinder(warp_size_).Find(op->body);
warp_coeff_ = WarpStoreCoeffFinder(buffer_, warp_index_, analyzer_).Find(op->body);

// Align the local memory size. The number of elements may not
Expand Down Expand Up @@ -357,7 +358,6 @@ class WarpMemoryRewriter : private StmtMutator {
if (warp_size_ == 1) return stmt;
BindVarBoundInfo binder(&analyzer_);
binder(stmt);
std::tie(warp_index_, warp_access_width_) = WarpIndexFinder(warp_size_).Find(stmt);
stmt = operator()(std::move(stmt));
return stmt;
}
Expand All @@ -369,7 +369,7 @@ class WarpMemoryRewriter : private StmtMutator {
auto ret = StmtMutator::VisitStmt_(op);
op = ret.as<AllocateNode>();
if (warp_buffer_.count(op->buffer_var.get())) {
WarpAccessRewriter rewriter(warp_size_, warp_index_, warp_access_width_, &analyzer_);
WarpAccessRewriter rewriter(warp_size_, &analyzer_);
ret = rewriter.Rewrite(op);
}
return ret;
Expand All @@ -389,8 +389,6 @@ class WarpMemoryRewriter : private StmtMutator {
}

int warp_size_{0};
int warp_access_width_{0};
Var warp_index_{NullValue<Var>()};
std::unordered_set<const VarNode*> warp_buffer_;
arith::Analyzer analyzer_;
// variable domain
Expand Down

0 comments on commit 228f3f7

Please sign in to comment.