Ported auto-tensorization code
masahi committed Apr 7, 2022
1 parent 534205b commit 86baa31
Showing 8 changed files with 431 additions and 4 deletions.
5 changes: 5 additions & 0 deletions include/tvm/tir/stmt.h
@@ -1509,6 +1509,11 @@ constexpr const char* meta_schedule_unroll_explicit = "meta_schedule.unroll_explicit";
/*! \brief Mark auto-unroll setting on the block. */
constexpr const char* meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit";

/*!
 * \brief Mark that the block should be further rewritten using tensorization.
 */
constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensorize";

/*!
 * \brief Check if attr_key is a pragma key extension
 * \param attr_key The attr key to be compared
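For orientation, a hedged sketch of how this annotation is meant to be used from the Python side (the schedule, block name, and intrinsic name below are assumptions for illustration, not part of this diff):

    import tvm
    from tvm import tir

    # sch: a tir.Schedule over some int8 workload (assumed to exist).
    # A search rule tags a block it wants rewritten by a later postproc:
    block = sch.get_block("update")
    sch.annotate(
        block,
        ann_key="meta_schedule.auto_tensorize",
        ann_val="dot_16x1x16_uint8_int8_int32_cascadelake",
    )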
2 changes: 1 addition & 1 deletion include/tvm/tir/stmt_functor.h
@@ -412,7 +412,7 @@ inline T Substitute(T input, const std::unordered_map<const VarNode*, PrimExpr>&
 * children of the node
 */
TVM_DLL void PreOrderVisit(const ObjectRef& stmt_or_expr,
                           const std::function<bool(const ObjectRef&)>& fvisit,
                           bool visit_init_block = true);
}  // namespace tir
}  // namespace tvm

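A small sketch of driving the visitor from Python: tvm.tir.stmt_functor.pre_order_visit wraps the C++ PreOrderVisit. Whether the Python binding exposes the new visit_init_block flag is not shown in this diff, so the sketch relies on the C++ default of true:

    import tvm
    from tvm.script import tir as T

    @T.prim_func
    def f(a: T.handle) -> None:
        A = T.match_buffer(a, (16,), "float32")
        for i in T.serial(0, 16):
            A[i] = 0.0

    def fvisit(obj):
        print(type(obj).__name__)
        return True  # True means: keep recursing into children

    tvm.tir.stmt_functor.pre_order_visit(f.body, fvisit)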
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/__init__.py
@@ -39,3 +39,4 @@
    tune_tir,
)
from .tune_context import TuneContext
from . import tensor_intrin
17 changes: 17 additions & 0 deletions python/tvm/meta_schedule/tensor_intrin/__init__.py
@@ -0,0 +1,17 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from . import vnni
70 changes: 70 additions & 0 deletions python/tvm/meta_schedule/tensor_intrin/vnni.py
@@ -0,0 +1,70 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from tvm import tir
from tvm.script import tir as T
from tvm.script.registry import register


@T.prim_func
def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
    B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
    C = T.match_buffer(c, (16,), "int32", offset_factor=1)

    with T.block("root"):
        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
        T.writes(C[0:16])
        for i in T.serial(0, 16):
            with T.init():
                C[i] = T.int32(0)
            for k in T.serial(0, 4):
                with T.block("update"):
                    vi, vk = T.axis.remap("SR", [i, k])
                    C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")


@T.prim_func
def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
    B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
    C = T.match_buffer(c, (16,), "int32", offset_factor=1)

    with T.block("root"):
        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
        T.writes(C[0:16])

        A_u8x4 = A.vload([0], "uint8x4")
        A_i32 = T.reinterpret(A_u8x4, dtype="int32")

        B_i8x64 = B.vload([0, 0], dtype="int8x64")
        B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")

        C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin(  # Note: this is an update +=
            T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
            T.uint32(0),
            T.int32x16(0),
            T.broadcast(A_i32, 16),
            B_i32x16,
            dtype="int32x16",
        )


tir.TensorIntrin.register(
    "dot_16x1x16_uint8_int8_int32_cascadelake", dot_product_desc, dot_product_intrin
)
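A minimal usage sketch for the intrinsic registered above (the workload, block name, and loop splits are assumptions; only the registered name comes from this diff): after tiling an int8 GEMM so that the innermost loops match the 16x4 description, the block can be tensorized by name:

    from tvm import tir
    import tvm.meta_schedule  # importing this package runs the registration above

    # The intrinsic is now retrievable by name:
    intrin = tir.TensorIntrin.get("dot_16x1x16_uint8_int8_int32_cascadelake")

    # sch: a tir.Schedule whose "update" block was tiled so that a loop nest of
    # extent 16 (spatial) x 4 (reduction) sits innermost (assumed):
    # sch.tensorize(inner_loop_rv, "dot_16x1x16_uint8_int8_int32_cascadelake")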
90 changes: 90 additions & 0 deletions src/meta_schedule/postproc/rewrite_vnni.cc
@@ -0,0 +1,90 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include "../utils.h"

namespace tvm {
namespace meta_schedule {

using tir::BlockRV;
using tir::LoopRV;

using BlockPosition = std::tuple<String, String, String>;

class RewriteVNNINode : public PostprocNode {
 public:
  // Inherited from PostprocNode
  void InitializeWithTuneContext(const TuneContext& context) final {}

  // Inherited from PostprocNode
  bool Apply(const tir::Schedule& sch) final;

  void VisitAttrs(tvm::AttrVisitor* v) {}

  static constexpr const char* _type_key = "meta_schedule.RewriteVNNI";
  TVM_DECLARE_FINAL_OBJECT_INFO(RewriteVNNINode, PostprocNode);
};

void CollectTensorized(const tir::Schedule& sch, const String& func_name,
                       const tir::PrimFuncNode* func, std::vector<BlockPosition>& tasks) {
  tir::PreOrderVisit(
      func->body,
      [&](const ObjectRef& obj) -> bool {
        if (const auto* block = obj.as<tir::BlockNode>()) {
          tir::StmtSRef block_sref = sch->GetSRef(block);
          if (Optional<String> intrin_name =
                  tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize)) {
            tasks.push_back(std::make_tuple(block_sref->StmtAs<tir::BlockNode>()->name_hint,
                                            func_name, intrin_name.value()));
          }
        }
        return true;
      },
      /*visit_init_block=*/false);
}

bool RewriteVNNINode::Apply(const tir::Schedule& sch) {
  std::vector<BlockPosition> tasks;
  for (const auto& kv : sch->mod()->functions) {
    GlobalVar g_var = kv.first;
    BaseFunc base_func = kv.second;
    if (const tir::PrimFuncNode* prim_func = base_func.as<tir::PrimFuncNode>()) {
      CollectTensorized(sch, g_var->name_hint, prim_func, tasks);
    }
  }
  for (const BlockPosition& task : tasks) {
    // Retrieve the block rv according to the task noted down before
    BlockRV block_rv = sch->GetBlock(std::get<0>(task), std::get<1>(task));
    String intrin_name = std::get<2>(task);
    sch->Unannotate(block_rv, tir::attr::meta_schedule_auto_tensorize);
    sch->Tensorize(block_rv, intrin_name);
  }
  return true;
}

Postproc RewriteVNNI() {
  ObjectPtr<RewriteVNNINode> n = make_object<RewriteVNNINode>();
  return Postproc(n);
}

TVM_REGISTER_NODE_TYPE(RewriteVNNINode);
TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteVNNI").set_body_typed(RewriteVNNI);

}  // namespace meta_schedule
}  // namespace tvm
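Since the postproc is exposed here only through the global registry (a Python-side wrapper may be among the changed files not shown above), a hedged sketch of constructing and applying it:

    import tvm

    # Construct the postproc through its registered global:
    rewrite_vnni = tvm.get_global_func("meta_schedule.PostprocRewriteVNNI")()

    # sch: a tir.Schedule whose blocks carry "meta_schedule.auto_tensorize"
    # annotations (assumed). Applying the postproc unannotates each marked
    # block and tensorizes it with the named intrinsic:
    # ok = rewrite_vnni.apply(sch)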