Skip to content

Commit

Permalink
Lower logical intrin and end-to-end demo (#448)
Browse files Browse the repository at this point in the history
* [WIP] Logical Layout lowering

* add intrin

* Logical inntrin lowering

* e2e demo

* LowerLogicalIntrin

* Remove num groups

* remove old demo

* rebase

* fix

* lower intrin pass
  • Loading branch information
vinx13 authored Sep 1, 2021
1 parent 6de9b49 commit ca726d7
Show file tree
Hide file tree
Showing 9 changed files with 528 additions and 8 deletions.
5 changes: 5 additions & 0 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,11 @@ TVM_DLL const Op& atomic_add();
*/
TVM_DLL const Op& tvm_memcpy_async();

/*!
* \brief tvm intrinsic for mfma instruction
*/
TVM_DLL const Op& tvm_mfma_sync();

/*! \brief The kind of structure field info used in intrinsic */
enum TVMStructFieldKind : int {
// array head address
Expand Down
7 changes: 6 additions & 1 deletion include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -534,11 +534,16 @@ TVM_DLL Pass FlattenBuffer();
TVM_DLL Pass UnifyThreadBinding();

/*!
* \brief Lower lower logical layout into physical layout.
* \brief Lower logical layout into physical layout.
* \return The pass.
*/
TVM_DLL Pass LowerLogicalLayout();

/*!
* \brief Lower logical intrinsics into physical intrinsics.
*/
TVM_DLL Pass LowerLogicalIntrin();

} // namespace transform
} // namespace tir
} // namespace tvm
Expand Down
12 changes: 7 additions & 5 deletions python/tvm/script/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,13 @@ def check_index(index: Union[int, PrimExpr]):
if index < 0:
report_error("Negative index is not allowed during buffer access", span)
elif isinstance(index, PrimExpr):
if index.dtype != "int32":
report_error(
"index expected an int32 type PrimExpr but got " + str(index.dtype),
index.span,
)
# FIXME(vinx13): Ramp is allowed when registering logical intrinsic implementatoins
# if index.dtype != "int32":
# report_error(
# "index expected an int32 type PrimExpr but got " + str(index.dtype),
# index.span,
# )
pass
else:
report_error(
"Unsupported index type, expected int or tvm.tir.PrimExpr, but got "
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,3 +812,14 @@ def LowerLogicalLayout():
The result pass
"""
return _ffi_api.LowerLogicalLayout()


def LowerLogicalIntrin():
"""Lower logical intrinsics to physical intrinsics.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.LowerLogicalIntrin()
3 changes: 3 additions & 0 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,9 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition, bool for
pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
pass_list.push_back(tir::transform::CompactBufferAllocation());
pass_list.push_back(tir::transform::LowerMatchBuffer());
pass_list.push_back(tir::transform::Simplify());
pass_list.push_back(tir::transform::LowerLogicalLayout());
pass_list.push_back(tir::transform::LowerLogicalIntrin());
pass_list.push_back(tir::transform::FlattenBuffer());
pass_list.push_back(tir::transform::UnifyThreadBinding());
}
Expand Down
3 changes: 3 additions & 0 deletions src/tir/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,9 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_pipeline_consumer_wait)
TIR_DEFINE_BUILTIN_FUNC(tvm_pipeline_consumer_release)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(tvm_mfma_sync)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

} // namespace builtin
} // namespace tir
} // namespace tvm
18 changes: 16 additions & 2 deletions src/tir/transforms/flatten_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,13 @@ class BufferFlattener : public StmtExprMutator {

Stmt VisitStmt_(const BufferStoreNode* op) final {
BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
return store->buffer.vstore(store->indices, store->value);
Array<PrimExpr> indices = store->indices;
if (indices.size()) {
if (const auto* ramp = indices.back().as<RampNode>()) {
indices.Set(indices.size() - 1, ramp->base);
}
}
return store->buffer.vstore(indices, store->value);
}

PrimExpr VisitExpr_(const VarNode* op) final {
Expand All @@ -127,7 +133,15 @@ class BufferFlattener : public StmtExprMutator {

PrimExpr VisitExpr_(const BufferLoadNode* op) final {
BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
return load->buffer.vload(load->indices, load->dtype);
Array<PrimExpr> indices = load->indices;
DataType dtype = load->dtype;
if (indices.size()) {
if (const auto* ramp = indices.back().as<RampNode>()) {
dtype = dtype.with_lanes(ramp->lanes);
indices.Set(indices.size() - 1, ramp->base);
}
}
return load->buffer.vload(indices, dtype);
}

PrimExpr VisitExpr_(const CallNode* op) final {
Expand Down
150 changes: 150 additions & 0 deletions src/tir/transforms/lower_logical_intrin.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \brief Lower logical intrinsics
* \file lower_logical_intrin.cc
*/
#include <tvm/arith/iter_affine_map.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/target.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

namespace tvm {
namespace tir {

struct LogicalIntrinRegistry {
static Map<String, PrimFunc> registry;
};

class LogicalIntrinBufferReplacer : public StmtExprMutator {
public:
explicit LogicalIntrinBufferReplacer(Map<Var, Buffer> buffer_var_to_new_buffer)
: buffer_var_to_new_buffer_(std::move(buffer_var_to_new_buffer)) {
}

PrimExpr VisitExpr_(const VarNode* op) final {
auto it = buffer_var_to_new_buffer_.find(GetRef<Var>(op));
if (it != buffer_var_to_new_buffer_.end()) {
return (*it).second->data;
}
return GetRef<Var>(op);
}

PrimExpr VisitExpr_(const BufferLoadNode* op) final {
BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
auto it = buffer_var_to_new_buffer_.find(load->buffer->data);
if (it != buffer_var_to_new_buffer_.end()) {
auto* n = load.CopyOnWrite();
n->buffer = (*it).second;
}
return load;
}

Stmt VisitStmt_(const BufferStoreNode* op) final {
BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
auto it = buffer_var_to_new_buffer_.find(store->buffer->data);
if (it != buffer_var_to_new_buffer_.end()) {
auto* n = store.CopyOnWrite();
n->buffer = (*it).second;
}
return store;
}

private:
Map<Var, Buffer> buffer_var_to_new_buffer_;
};

class LogicalIntrinMutator : public StmtMutator {
public:
using FLowerLogicalIntrin = runtime::TypedPackedFunc<PrimFunc(PrimExpr)>;

explicit LogicalIntrinMutator(const PrimFunc& func) {
for (const auto& kv : func->buffer_map) {
const Buffer& buffer = kv.second;
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
}

Stmt VisitStmt_(const BlockNode* op) {
for (const auto& buffer : op->alloc_buffers) {
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
return StmtMutator::VisitStmt_(op);
}

Stmt VisitStmt_(const EvaluateNode* op) {
static const auto& f_lower_logical_intrin = Op::GetAttrMap<PrimFunc>("FLowerLogicalIntrin");
if (const auto* call = op->value.as<CallNode>()) {
if (const auto* call_op = call->op.as<OpNode>()) {
PrimFunc intrin_impl = f_lower_logical_intrin.get(GetRef<Op>(call_op), NullValue<PrimFunc>());
if (intrin_impl.defined()) {
// Make inlined call to intrin_impl
CHECK(intrin_impl->params.size() == call->args.size());
Map<Var, PrimExpr> subst_map;
for (size_t i = 0; i < call->args.size(); i++) {
subst_map.Set(intrin_impl->params[i], call->args[i]);
}
Map<Var, Buffer> new_buffer_map;
for (size_t i = 0; i < call->args.size(); i++) {
const auto& param = intrin_impl->params[i];
if (const auto* var = param.as<VarNode>()) {
if (var->dtype.is_handle()) {
Var buffer_var = Downcast<Var>(param);
auto it = intrin_impl->buffer_map.find(buffer_var);
CHECK(it != intrin_impl->buffer_map.end()) << buffer_var;
if (it != intrin_impl->buffer_map.end()) {
new_buffer_map.Set((*it).second->data,
buffer_data_to_buffer_.at(Downcast<Var>(call->args[i])));
}
}
}
}

auto body = Substitute(intrin_impl->body, subst_map);
return LogicalIntrinBufferReplacer(new_buffer_map)(body);
}
}
}
return StmtMutator::VisitStmt_(op);
}

Map<Var, Buffer> buffer_data_to_buffer_;
};

namespace transform {

Pass LowerLogicalIntrin() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = LogicalIntrinMutator(f)(std::move(f->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.LowerLogicalLayout", {});
}

TVM_REGISTER_GLOBAL("tir.transform.LowerLogicalIntrin").set_body_typed(LowerLogicalIntrin);

} // namespace transform

} // namespace tir
} // namespace tvm
Loading

0 comments on commit ca726d7

Please sign in to comment.